diff --git a/.changes/next-release/feature-AmazonS3-a20b910.json b/.changes/next-release/feature-AmazonS3-a20b910.json new file mode 100644 index 000000000000..2ca22a256608 --- /dev/null +++ b/.changes/next-release/feature-AmazonS3-a20b910.json @@ -0,0 +1,6 @@ +{ + "category": "Amazon S3", + "contributor": "", + "type": "feature", + "description": "Add support for pause/resume upload for TransferManager with Java-based S3Client that has multipart enabled" +} diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java index 98f81b731014..eb45d1f7370d 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3IntegrationTestBase.java @@ -68,6 +68,7 @@ public static void setUpForAllIntegTests() throws Exception { Log.initLoggingToStdout(Log.LogLevel.Warn); System.setProperty("aws.crt.debugnative", "true"); s3 = s3ClientBuilder().build(); + // TODO - enable multipart once TransferListener fixed for MultipartClient s3Async = s3AsyncClientBuilder().build(); s3CrtAsync = S3CrtAsyncClient.builder() .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) diff --git a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java index 23705c6bc5bf..e872080f2e32 100644 --- a/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java +++ b/services-custom/s3-transfer-manager/src/it/java/software/amazon/awssdk/transfer/s3/S3TransferManagerUploadPauseResumeIntegrationTest.java @@ -16,7 +16,6 @@ package software.amazon.awssdk.transfer.s3; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; @@ -26,13 +25,17 @@ import java.time.Duration; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.core.retry.backoff.FixedDelayBackoffStrategy; import software.amazon.awssdk.core.waiters.AsyncWaiter; import software.amazon.awssdk.core.waiters.Waiter; import software.amazon.awssdk.core.waiters.WaiterAcceptor; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.ListMultipartUploadsResponse; import software.amazon.awssdk.services.s3.model.ListPartsResponse; import software.amazon.awssdk.services.s3.model.NoSuchUploadException; @@ -48,17 +51,25 @@ public class S3TransferManagerUploadPauseResumeIntegrationTest extends S3Integra private static final String BUCKET = temporaryBucketName(S3TransferManagerUploadPauseResumeIntegrationTest.class); private static final String KEY = "key"; // 24 * MB is chosen to make sure we have data written in the file already upon pausing. - private static final long OBJ_SIZE = 24 * MB; + private static final long LARGE_OBJ_SIZE = 24 * MB; + private static final long SMALL_OBJ_SIZE = 2 * MB; private static File largeFile; private static File smallFile; private static ScheduledExecutorService executorService; + // TODO - switch to tmJava from TestBase once TransferListener fixed for MultipartClient + protected static S3TransferManager tmJavaMpu; + @BeforeAll public static void setup() throws Exception { createBucket(BUCKET); - largeFile = new RandomTempFile(OBJ_SIZE); - smallFile = new RandomTempFile(2 * MB); + largeFile = new RandomTempFile(LARGE_OBJ_SIZE); + smallFile = new RandomTempFile(SMALL_OBJ_SIZE); executorService = Executors.newScheduledThreadPool(3); + + // TODO - switch to tmJava from TestBase once TransferListener fixed for MultipartClient + S3AsyncClient s3AsyncMpu = s3AsyncClientBuilder().multipartEnabled(true).build(); + tmJavaMpu = S3TransferManager.builder().s3Client(s3AsyncMpu).build(); } @AfterAll @@ -69,30 +80,42 @@ public static void cleanup() { executorService.shutdown(); } - @Test - void pause_singlePart_shouldResume() { + private static Stream transferManagers() { + return Stream.of( + Arguments.of(tmJavaMpu, tmJavaMpu), + Arguments.of(tmCrt, tmCrt), + Arguments.of(tmCrt, tmJavaMpu), + Arguments.of(tmJavaMpu, tmCrt) + ); + } + + @ParameterizedTest + @MethodSource("transferManagers") + void pause_singlePart_shouldResume(S3TransferManager uploadTm, S3TransferManager resumeTm) { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) .source(smallFile) .build(); - FileUpload fileUpload = tmCrt.uploadFile(request); + FileUpload fileUpload = uploadTm.uploadFile(request); ResumableFileUpload resumableFileUpload = fileUpload.pause(); log.debug(() -> "Paused: " + resumableFileUpload); validateEmptyResumeToken(resumableFileUpload); - FileUpload resumedUpload = tmCrt.resumeUploadFile(resumableFileUpload); + FileUpload resumedUpload = resumeTm.resumeUploadFile(resumableFileUpload); resumedUpload.completionFuture().join(); + assertThat(resumedUpload.progress().snapshot().totalBytes()).hasValue(SMALL_OBJ_SIZE); } - @Test - void pause_fileNotChanged_shouldResume() { + @ParameterizedTest + @MethodSource("transferManagers") + void pause_fileNotChanged_shouldResume(S3TransferManager uploadTm, S3TransferManager resumeTm) throws Exception { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) .addTransferListener(LoggingTransferListener.create()) .source(largeFile) .build(); - FileUpload fileUpload = tmCrt.uploadFile(request); + FileUpload fileUpload = uploadTm.uploadFile(request); waitUntilMultipartUploadExists(); ResumableFileUpload resumableFileUpload = fileUpload.pause(); log.debug(() -> "Paused: " + resumableFileUpload); @@ -103,33 +126,37 @@ void pause_fileNotChanged_shouldResume() { verifyMultipartUploadIdExists(resumableFileUpload); - FileUpload resumedUpload = tmCrt.resumeUploadFile(resumableFileUpload); + FileUpload resumedUpload = resumeTm.resumeUploadFile(resumableFileUpload); resumedUpload.completionFuture().join(); + assertThat(resumedUpload.progress().snapshot().totalBytes()).hasValue(LARGE_OBJ_SIZE); } - @Test - void pauseImmediately_resume_shouldStartFromBeginning() { + @ParameterizedTest + @MethodSource("transferManagers") + void pauseImmediately_resume_shouldStartFromBeginning(S3TransferManager uploadTm, S3TransferManager resumeTm) { UploadFileRequest request = UploadFileRequest.builder() - .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) - .source(largeFile) - .build(); - FileUpload fileUpload = tmCrt.uploadFile(request); + .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) + .source(largeFile) + .build(); + FileUpload fileUpload = uploadTm.uploadFile(request); ResumableFileUpload resumableFileUpload = fileUpload.pause(); log.debug(() -> "Paused: " + resumableFileUpload); validateEmptyResumeToken(resumableFileUpload); - FileUpload resumedUpload = tmCrt.resumeUploadFile(resumableFileUpload); + FileUpload resumedUpload = resumeTm.resumeUploadFile(resumableFileUpload); resumedUpload.completionFuture().join(); + assertThat(resumedUpload.progress().snapshot().totalBytes()).hasValue(LARGE_OBJ_SIZE); } - @Test - void pause_fileChanged_resumeShouldStartFromBeginning() throws Exception { + @ParameterizedTest + @MethodSource("transferManagers") + void pause_fileChanged_resumeShouldStartFromBeginning(S3TransferManager uploadTm, S3TransferManager resumeTm) throws Exception { UploadFileRequest request = UploadFileRequest.builder() .putObjectRequest(b -> b.bucket(BUCKET).key(KEY)) .source(largeFile) .build(); - FileUpload fileUpload = tmCrt.uploadFile(request); + FileUpload fileUpload = uploadTm.uploadFile(request); waitUntilMultipartUploadExists(); ResumableFileUpload resumableFileUpload = fileUpload.pause(); log.debug(() -> "Paused: " + resumableFileUpload); @@ -139,13 +166,18 @@ void pause_fileChanged_resumeShouldStartFromBeginning() throws Exception { assertThat(resumableFileUpload.totalParts()).isNotEmpty(); verifyMultipartUploadIdExists(resumableFileUpload); - byte[] bytes = "helloworld".getBytes(StandardCharsets.UTF_8); - Files.write(largeFile.toPath(), bytes); - - FileUpload resumedUpload = tmCrt.resumeUploadFile(resumableFileUpload); - resumedUpload.completionFuture().join(); - verifyMultipartUploadIdNotExist(resumableFileUpload); - assertThat(resumedUpload.progress().snapshot().totalBytes()).hasValue(bytes.length); + byte[] originalBytes = Files.readAllBytes(largeFile.toPath()); + try { + byte[] bytes = "helloworld".getBytes(StandardCharsets.UTF_8); + Files.write(largeFile.toPath(), bytes); + + FileUpload resumedUpload = resumeTm.resumeUploadFile(resumableFileUpload); + resumedUpload.completionFuture().join(); + verifyMultipartUploadIdNotExist(resumableFileUpload); + assertThat(resumedUpload.progress().snapshot().totalBytes()).hasValue(bytes.length); + } finally { + Files.write(largeFile.toPath(), originalBytes); + } } private void verifyMultipartUploadIdExists(ResumableFileUpload resumableFileUpload) { diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/CrtS3TransferManager.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/CrtS3TransferManager.java index eef9205be1c7..71ebeef56e62 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/CrtS3TransferManager.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/CrtS3TransferManager.java @@ -20,7 +20,6 @@ import static software.amazon.awssdk.services.s3.crt.S3CrtSdkHttpExecutionAttribute.METAREQUEST_PAUSE_OBSERVABLE; import static software.amazon.awssdk.services.s3.internal.crt.S3InternalSdkHttpExecutionAttribute.CRT_PAUSE_RESUME_TOKEN; import static software.amazon.awssdk.transfer.s3.internal.GenericS3TransferManager.assertNotUnsupportedArn; -import static software.amazon.awssdk.transfer.s3.internal.utils.FileUtils.fileNotModified; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; @@ -31,7 +30,6 @@ import software.amazon.awssdk.http.SdkHttpExecutionAttributes; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.internal.crt.S3MetaRequestPauseObservable; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.transfer.s3.S3TransferManager; @@ -51,6 +49,7 @@ @SdkInternalApi class CrtS3TransferManager extends DelegatingS3TransferManager { private static final Logger log = Logger.loggerFor(S3TransferManager.class); + private static final PauseResumeHelper PAUSE_RESUME_HELPER = new PauseResumeHelper(); private final S3AsyncClient s3AsyncClient; CrtS3TransferManager(TransferManagerConfiguration transferConfiguration, S3AsyncClient s3AsyncClient, @@ -99,67 +98,15 @@ public FileUpload uploadFile(UploadFileRequest uploadFileRequest) { return new CrtFileUpload(returnFuture, progressUpdater.progress(), observable, uploadFileRequest); } - private FileUpload uploadFromBeginning(ResumableFileUpload resumableFileUpload, boolean fileModified, - boolean noResumeToken) { - UploadFileRequest uploadFileRequest = resumableFileUpload.uploadFileRequest(); - PutObjectRequest putObjectRequest = uploadFileRequest.putObjectRequest(); - if (fileModified) { - log.debug(() -> String.format("The file (%s) has been modified since " - + "the last pause. " + - "The SDK will upload the requested object in bucket" - + " (%s) with key (%s) from " - + "the " - + "beginning.", - uploadFileRequest.source(), - putObjectRequest.bucket(), - putObjectRequest.key())); - resumableFileUpload.multipartUploadId() - .ifPresent(id -> { - log.debug(() -> "Aborting previous upload with multipartUploadId: " + id); - s3AsyncClient.abortMultipartUpload( - AbortMultipartUploadRequest.builder() - .bucket(putObjectRequest.bucket()) - .key(putObjectRequest.key()) - .uploadId(id) - .build()) - .exceptionally(t -> { - log.warn(() -> String.format("Failed to abort previous multipart upload " - + "(id: %s)" - + ". You may need to call " - + "S3AsyncClient#abortMultiPartUpload to " - + "free all storage consumed by" - + " all parts. ", - id), t); - return null; - }); - }); - } - - if (noResumeToken) { - log.debug(() -> String.format("No resume token is found. " + - "The SDK will upload the requested object in bucket" - + " (%s) with key (%s) from " - + "the beginning.", - putObjectRequest.bucket(), - putObjectRequest.key())); - } - - - return uploadFile(uploadFileRequest); - } - @Override public FileUpload resumeUploadFile(ResumableFileUpload resumableFileUpload) { Validate.paramNotNull(resumableFileUpload, "resumableFileUpload"); - boolean fileModified = !fileNotModified(resumableFileUpload.fileLength(), - resumableFileUpload.fileLastModified(), - resumableFileUpload.uploadFileRequest().source()); - - boolean noResumeToken = !hasResumeToken(resumableFileUpload); + boolean fileModified = PAUSE_RESUME_HELPER.fileModified(resumableFileUpload, s3AsyncClient); + boolean noResumeToken = !PAUSE_RESUME_HELPER.hasResumeToken(resumableFileUpload); if (fileModified || noResumeToken) { - return uploadFromBeginning(resumableFileUpload, fileModified, noResumeToken); + return uploadFile(resumableFileUpload.uploadFileRequest()); } return doResumeUpload(resumableFileUpload); @@ -188,10 +135,6 @@ private static ResumeToken crtResumeToken(ResumableFileUpload resumableFileUploa .withUploadId(resumableFileUpload.multipartUploadId().orElse(null))); } - private boolean hasResumeToken(ResumableFileUpload resumableFileUpload) { - return resumableFileUpload.totalParts().isPresent() && resumableFileUpload.partSizeInBytes().isPresent(); - } - private PutObjectRequest attachSdkAttribute(PutObjectRequest putObjectRequest, Consumer builderMutation) { SdkHttpExecutionAttributes modifiedAttributes = diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java index b06d0824b709..ec9f25a133c1 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java @@ -15,14 +15,18 @@ package software.amazon.awssdk.transfer.s3.internal; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.PAUSE_OBSERVABLE; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.RESUME_TOKEN; import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; import static software.amazon.awssdk.transfer.s3.internal.utils.ResumableRequestConverter.toDownloadFileRequestAndTransformer; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.function.Consumer; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.annotations.SdkTestInternalApi; import software.amazon.awssdk.arns.Arn; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.FileTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; @@ -30,6 +34,7 @@ import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; import software.amazon.awssdk.services.s3.internal.resource.S3AccessPointResource; import software.amazon.awssdk.services.s3.internal.resource.S3ArnConverter; import software.amazon.awssdk.services.s3.internal.resource.S3Resource; @@ -39,6 +44,8 @@ import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.multipart.PauseObservable; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; import software.amazon.awssdk.transfer.s3.S3TransferManager; import software.amazon.awssdk.transfer.s3.internal.model.DefaultCopy; import software.amazon.awssdk.transfer.s3.internal.model.DefaultDirectoryDownload; @@ -65,6 +72,7 @@ import software.amazon.awssdk.transfer.s3.model.FileDownload; import software.amazon.awssdk.transfer.s3.model.FileUpload; import software.amazon.awssdk.transfer.s3.model.ResumableFileDownload; +import software.amazon.awssdk.transfer.s3.model.ResumableFileUpload; import software.amazon.awssdk.transfer.s3.model.Upload; import software.amazon.awssdk.transfer.s3.model.UploadDirectoryRequest; import software.amazon.awssdk.transfer.s3.model.UploadFileRequest; @@ -80,6 +88,7 @@ class GenericS3TransferManager implements S3TransferManager { protected static final int DEFAULT_FILE_UPLOAD_CHUNK_SIZE = (int) (16 * MB); private static final Logger log = Logger.loggerFor(S3TransferManager.class); + private static final PauseResumeHelper PAUSE_RESUME_HELPER = new PauseResumeHelper(); private final S3AsyncClient s3AsyncClient; private final UploadDirectoryHelper uploadDirectoryHelper; private final DownloadDirectoryHelper downloadDirectoryHelper; @@ -157,6 +166,15 @@ public FileUpload uploadFile(UploadFileRequest uploadFileRequest) { .build(); PutObjectRequest putObjectRequest = uploadFileRequest.putObjectRequest(); + PauseObservable pauseObservable; + if (isS3ClientMultipartEnabled()) { + pauseObservable = new PauseObservable(); + Consumer attachPauseObservable = + b -> b.putExecutionAttribute(PAUSE_OBSERVABLE, pauseObservable); + putObjectRequest = attachSdkAttribute(uploadFileRequest.putObjectRequest(), attachPauseObservable); + } else { + pauseObservable = null; + } CompletableFuture returnFuture = new CompletableFuture<>(); @@ -182,8 +200,72 @@ public FileUpload uploadFile(UploadFileRequest uploadFileRequest) { } catch (Throwable throwable) { returnFuture.completeExceptionally(throwable); } + return new DefaultFileUpload(returnFuture, progressUpdater.progress(), pauseObservable, uploadFileRequest); + } + + @Override + public FileUpload resumeUploadFile(ResumableFileUpload resumableFileUpload) { + Validate.paramNotNull(resumableFileUpload, "resumableFileUpload"); + + boolean fileModified = PAUSE_RESUME_HELPER.fileModified(resumableFileUpload, s3AsyncClient); + boolean noResumeToken = !PAUSE_RESUME_HELPER.hasResumeToken(resumableFileUpload); + + if (fileModified || noResumeToken) { + return uploadFile(resumableFileUpload.uploadFileRequest()); + } + + return doResumeUpload(resumableFileUpload); + } + + private boolean isS3ClientMultipartEnabled() { + // TODO use configuration getter when available + return s3AsyncClient instanceof MultipartS3AsyncClient; + } + + private FileUpload doResumeUpload(ResumableFileUpload resumableFileUpload) { + UploadFileRequest uploadFileRequest = resumableFileUpload.uploadFileRequest(); + PutObjectRequest putObjectRequest = uploadFileRequest.putObjectRequest(); + S3ResumeToken s3ResumeToken = s3ResumeToken(resumableFileUpload); + + Consumer attachResumeToken = + b -> b.putExecutionAttribute(RESUME_TOKEN, s3ResumeToken); + + PutObjectRequest modifiedPutObjectRequest = attachSdkAttribute(putObjectRequest, attachResumeToken); + + return uploadFile(uploadFileRequest.toBuilder() + .putObjectRequest(modifiedPutObjectRequest) + .build()); + } + + private static S3ResumeToken s3ResumeToken(ResumableFileUpload resumableFileUpload) { + S3ResumeToken.Builder builder = S3ResumeToken.builder(); + + builder.uploadId(resumableFileUpload.multipartUploadId().orElse(null)); + if (resumableFileUpload.partSizeInBytes().isPresent()) { + builder.partSize(resumableFileUpload.partSizeInBytes().getAsLong()); + } + if (resumableFileUpload.totalParts().isPresent()) { + builder.totalNumParts(resumableFileUpload.totalParts().getAsLong()); + } + if (resumableFileUpload.transferredParts().isPresent()) { + builder.numPartsCompleted(resumableFileUpload.transferredParts().getAsLong()); + } + + return builder.build(); + } - return new DefaultFileUpload(returnFuture, progressUpdater.progress(), uploadFileRequest); + private PutObjectRequest attachSdkAttribute(PutObjectRequest putObjectRequest, + Consumer builderMutation) { + AwsRequestOverrideConfiguration modifiedRequestOverrideConfig = + putObjectRequest.overrideConfiguration() + .map(o -> o.toBuilder().applyMutation(builderMutation).build()) + .orElseGet(() -> AwsRequestOverrideConfiguration.builder() + .applyMutation(builderMutation) + .build()); + + return putObjectRequest.toBuilder() + .overrideConfiguration(modifiedRequestOverrideConfig) + .build(); } @Override diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/PauseResumeHelper.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/PauseResumeHelper.java new file mode 100644 index 000000000000..9c5220f388e4 --- /dev/null +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/PauseResumeHelper.java @@ -0,0 +1,91 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.transfer.s3.internal; + +import static software.amazon.awssdk.transfer.s3.internal.utils.FileUtils.fileNotModified; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.transfer.s3.model.ResumableFileUpload; +import software.amazon.awssdk.transfer.s3.model.UploadFileRequest; +import software.amazon.awssdk.utils.Logger; + +@SdkInternalApi +public class PauseResumeHelper { + private static final Logger log = Logger.loggerFor(PauseResumeHelper.class); + + protected boolean fileModified(ResumableFileUpload resumableFileUpload, S3AsyncClient s3AsyncClient) { + boolean fileModified = !fileNotModified(resumableFileUpload.fileLength(), + resumableFileUpload.fileLastModified(), + resumableFileUpload.uploadFileRequest().source()); + + if (fileModified) { + UploadFileRequest uploadFileRequest = resumableFileUpload.uploadFileRequest(); + PutObjectRequest putObjectRequest = uploadFileRequest.putObjectRequest(); + log.debug(() -> String.format("The file (%s) has been modified since " + + "the last pause. " + + "The SDK will upload the requested object in bucket" + + " (%s) with key (%s) from " + + "the " + + "beginning.", + uploadFileRequest.source(), + putObjectRequest.bucket(), + putObjectRequest.key())); + resumableFileUpload.multipartUploadId() + .ifPresent(id -> { + log.debug(() -> "Aborting previous upload with multipartUploadId: " + id); + s3AsyncClient.abortMultipartUpload( + AbortMultipartUploadRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .uploadId(id) + .build()) + .exceptionally(t -> { + log.warn(() -> String.format("Failed to abort previous multipart upload " + + "(id: %s)" + + ". You may need to call " + + "S3AsyncClient#abortMultiPartUpload to " + + "free all storage consumed by" + + " all parts. ", + id), t); + return null; + }); + }); + } + + return fileModified; + } + + protected boolean hasResumeToken(ResumableFileUpload resumableFileUpload) { + boolean hasResumeToken = + resumableFileUpload.totalParts().isPresent() && resumableFileUpload.partSizeInBytes().isPresent(); + + if (!hasResumeToken) { + UploadFileRequest uploadFileRequest = resumableFileUpload.uploadFileRequest(); + PutObjectRequest putObjectRequest = uploadFileRequest.putObjectRequest(); + log.debug(() -> String.format("No resume token is found. " + + "The SDK will upload the requested object in bucket" + + " (%s) with key (%s) from " + + "the beginning.", + putObjectRequest.bucket(), + putObjectRequest.key())); + } + + return hasResumeToken; + } +} diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/CrtFileUpload.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/CrtFileUpload.java index 4f7a4a757c2c..790fb0d2ba60 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/CrtFileUpload.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/CrtFileUpload.java @@ -148,7 +148,7 @@ public int hashCode() { @Override public String toString() { - return ToString.builder("DefaultFileUpload") + return ToString.builder("CrtFileUpload") .add("completionFuture", completionFuture) .add("progress", progress) .add("request", request) diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileUpload.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileUpload.java index 1579c64dbdf1..66647d27cd61 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileUpload.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/model/DefaultFileUpload.java @@ -15,35 +15,77 @@ package software.amazon.awssdk.transfer.s3.internal.model; +import java.io.File; +import java.time.Instant; import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.services.s3.multipart.PauseObservable; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; import software.amazon.awssdk.transfer.s3.model.CompletedFileUpload; import software.amazon.awssdk.transfer.s3.model.FileUpload; import software.amazon.awssdk.transfer.s3.model.ResumableFileUpload; import software.amazon.awssdk.transfer.s3.model.UploadFileRequest; import software.amazon.awssdk.transfer.s3.progress.TransferProgress; +import software.amazon.awssdk.utils.Lazy; import software.amazon.awssdk.utils.ToString; import software.amazon.awssdk.utils.Validate; @SdkInternalApi public final class DefaultFileUpload implements FileUpload { + private final Lazy resumableFileUpload; private final CompletableFuture completionFuture; private final TransferProgress progress; private final UploadFileRequest request; + private final PauseObservable pauseObservable; public DefaultFileUpload(CompletableFuture completionFuture, TransferProgress progress, + PauseObservable pauseObservable, UploadFileRequest request) { this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture"); this.progress = Validate.paramNotNull(progress, "progress"); this.request = Validate.paramNotNull(request, "request"); + this.pauseObservable = pauseObservable; + this.resumableFileUpload = new Lazy<>(this::doPause); } @Override public ResumableFileUpload pause() { - throw new UnsupportedOperationException("Pausing an upload is not supported in a non CRT-based S3 Client. For " - + "upload pause support, pass an AWS CRT-based S3 client to S3TransferManager" - + "instead: S3AsyncClient.crtBuilder().build();"); + if (pauseObservable == null) { + throw new UnsupportedOperationException("Pausing an upload is not supported in a non CRT-based S3Client that does " + + "not have multipart configuration enabled. For upload pause support, pass " + + "a CRT-based S3Client or an S3Client with multipart enabled to " + + "S3TransferManager."); + } + + return resumableFileUpload.getValue(); + } + + private ResumableFileUpload doPause() { + File sourceFile = request.source().toFile(); + Instant fileLastModified = Instant.ofEpochMilli(sourceFile.lastModified()); + + ResumableFileUpload.Builder resumableFileBuilder = ResumableFileUpload.builder() + .fileLastModified(fileLastModified) + .fileLength(sourceFile.length()) + .uploadFileRequest(request); + + if (completionFuture.isDone()) { + return resumableFileBuilder.build(); + } + + S3ResumeToken token = pauseObservable.pause(); + + // Upload hasn't started yet, or it's a single object upload + if (token == null) { + return resumableFileBuilder.build(); + } + + return resumableFileBuilder.multipartUploadId(token.uploadId()) + .totalParts(token.totalNumParts()) + .transferredParts(token.numPartsCompleted()) + .partSizeInBytes(token.partSize()) + .build(); } @Override @@ -67,20 +109,28 @@ public boolean equals(Object o) { DefaultFileUpload that = (DefaultFileUpload) o; + if (!resumableFileUpload.equals(that.resumableFileUpload)) { + return false; + } if (!completionFuture.equals(that.completionFuture)) { return false; } if (!progress.equals(that.progress)) { return false; } - return request.equals(that.request); + if (!request.equals(that.request)) { + return false; + } + return pauseObservable == that.pauseObservable; } @Override public int hashCode() { - int result = completionFuture.hashCode(); + int result = resumableFileUpload.hashCode(); + result = 31 * result + completionFuture.hashCode(); result = 31 * result + progress.hashCode(); result = 31 * result + request.hashCode(); + result = 31 * result + pauseObservable.hashCode(); return result; } diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/FileUpload.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/FileUpload.java index 28486d76bd0b..90e99f2829f0 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/FileUpload.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/model/FileUpload.java @@ -18,6 +18,8 @@ import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; import software.amazon.awssdk.transfer.s3.S3TransferManager; /** @@ -32,12 +34,15 @@ public interface FileUpload extends ObjectTransfer { *

* The information object is serializable for persistent storage until it should be resumed. * See {@link ResumableFileUpload} for supported formats. - * + * *

* Currently, it's only supported if the underlying {@link S3AsyncClient} is CRT-based (created via - * {@link S3AsyncClient#crtBuilder()} or {@link S3AsyncClient#crtCreate()}). + * {@link S3AsyncClient#crtBuilder()} or {@link S3AsyncClient#crtCreate()}), OR the underlying + * {@link S3AsyncClient} has multipart enabled ({@link S3AsyncClientBuilder#multipartConfiguration(MultipartConfiguration)} + * or {@link S3AsyncClientBuilder#multipartEnabled(Boolean)}). * It will throw {@link UnsupportedOperationException} if the {@link S3TransferManager} is created - * with a non CRT-based S3 client (created via {@link S3AsyncClient#builder()}). + * with a non CRT-based S3 client (created via {@link S3AsyncClient#builder()}) and does not have + * multipart configuration enabled. * * @return A {@link ResumableFileUpload} that can be used to resume the upload. */ diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/DefaultFileUploadTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/DefaultFileUploadTest.java index 539433734920..f7e523d9355f 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/DefaultFileUploadTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/DefaultFileUploadTest.java @@ -15,38 +15,179 @@ package software.amazon.awssdk.transfer.s3.internal; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; -import java.nio.file.Paths; +import com.google.common.jimfs.Jimfs; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.time.Instant; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.Mockito; -import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.services.s3.internal.multipart.PausableUpload; +import software.amazon.awssdk.services.s3.multipart.PauseObservable; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.transfer.s3.internal.model.DefaultFileUpload; -import software.amazon.awssdk.transfer.s3.model.FileUpload; +import software.amazon.awssdk.transfer.s3.internal.progress.DefaultTransferProgressSnapshot; +import software.amazon.awssdk.transfer.s3.model.CompletedFileUpload; +import software.amazon.awssdk.transfer.s3.model.ResumableFileUpload; import software.amazon.awssdk.transfer.s3.model.UploadFileRequest; import software.amazon.awssdk.transfer.s3.progress.TransferProgress; class DefaultFileUploadTest { + private static final long TOTAL_PARTS = 10; + private static final long NUM_OF_PARTS_COMPLETED = 5; + private static final long PART_SIZE_IN_BYTES = 8 * MB; + private static final String MULTIPART_UPLOAD_ID = "someId"; + private static FileSystem fileSystem; + private static File file; + private static S3ResumeToken token; + + @BeforeAll + public static void setUp() throws IOException { + fileSystem = Jimfs.newFileSystem(); + file = File.createTempFile("test", UUID.randomUUID().toString()); + Files.write(file.toPath(), RandomStringUtils.random(2000).getBytes(StandardCharsets.UTF_8)); + token = S3ResumeToken.builder() + .uploadId(MULTIPART_UPLOAD_ID) + .totalNumParts(TOTAL_PARTS) + .numPartsCompleted(NUM_OF_PARTS_COMPLETED) + .partSize(PART_SIZE_IN_BYTES) + .build(); + } + + @AfterAll + public static void tearDown() throws IOException { + file.delete(); + } + @Test void equals_hashcode() { EqualsVerifier.forClass(DefaultFileUpload.class) - .withNonnullFields("completionFuture", "progress", "request") + .withNonnullFields("completionFuture", "progress", "request", "resumableFileUpload", "pauseObservable") + .withPrefabValues(PauseObservable.class, new PauseObservable(), new PauseObservable()) .verify(); } + private S3ResumeToken s3ResumeToken(CompletableFuture future) { + if (future.isDone()) { + return null; + } + return token; + } + @Test - void pause_shouldThrowUnsupportedOperation() { - TransferProgress transferProgress = Mockito.mock(TransferProgress.class); - UploadFileRequest request = UploadFileRequest.builder() - .source(Paths.get("test")) - .putObjectRequest(p -> p.key("test").bucket("bucket")) - .build(); - FileUpload fileUpload = new DefaultFileUpload(new CompletableFuture<>(), - transferProgress, - request); - - assertThatThrownBy(() -> fileUpload.pause()).isInstanceOf(UnsupportedOperationException.class); + void pause_futureCompleted_shouldReturnNormally() { + CompletableFuture future = + CompletableFuture.completedFuture(CompletedFileUpload.builder() + .response(PutObjectResponse.builder().build()) + .build()); + TransferProgress transferProgress = mock(TransferProgress.class); + PauseObservable observable = new PauseObservable(); + PausableUpload pausableUpload = mock(PausableUpload.class); + observable.setPausableUpload(pausableUpload); + when(pausableUpload.pause()).thenReturn(s3ResumeToken(future)); + + UploadFileRequest request = uploadFileRequest(); + DefaultFileUpload fileUpload = new DefaultFileUpload(future, transferProgress, observable, request); + + ResumableFileUpload resumableFileUpload = fileUpload.pause(); + + verify(pausableUpload, Mockito.never()).pause(); + assertThat(resumableFileUpload.totalParts()).isEmpty(); + assertThat(resumableFileUpload.partSizeInBytes()).isEmpty(); + assertThat(resumableFileUpload.multipartUploadId()).isEmpty(); + assertThat(resumableFileUpload.fileLength()).isEqualTo(file.length()); + assertThat(resumableFileUpload.uploadFileRequest()).isEqualTo(request); + assertThat(resumableFileUpload.fileLastModified()).isEqualTo(Instant.ofEpochMilli(file.lastModified())); + } + + + @Test + void pauseTwice_shouldReturnTheSame() { + CompletableFuture future = new CompletableFuture<>(); + TransferProgress transferProgress = mock(TransferProgress.class); + PauseObservable observable = new PauseObservable(); + PausableUpload pausableUpload = mock(PausableUpload.class); + observable.setPausableUpload(pausableUpload); + when(pausableUpload.pause()).thenReturn(s3ResumeToken(future)); + + UploadFileRequest request = uploadFileRequest(); + DefaultFileUpload fileUpload = new DefaultFileUpload(future, transferProgress, observable, request); + + ResumableFileUpload resumableFileUpload = fileUpload.pause(); + ResumableFileUpload resumableFileUpload2 = fileUpload.pause(); + + verify(pausableUpload).pause(); + assertThat(resumableFileUpload).isEqualTo(resumableFileUpload2); + } + + @Test + void pause_futureNotComplete_shouldPause() { + CompletableFuture future = new CompletableFuture<>(); + TransferProgress transferProgress = mock(TransferProgress.class); + when(transferProgress.snapshot()).thenReturn(DefaultTransferProgressSnapshot.builder() + .transferredBytes(0L) + .build()); + + + PauseObservable observable = new PauseObservable(); + PausableUpload pausableUpload = mock(PausableUpload.class); + observable.setPausableUpload(pausableUpload); + when(pausableUpload.pause()).thenReturn(s3ResumeToken(future)); + + UploadFileRequest request = uploadFileRequest(); + DefaultFileUpload fileUpload = new DefaultFileUpload(future, transferProgress, observable, request); + + ResumableFileUpload resumableFileUpload = fileUpload.pause(); + + verify(pausableUpload).pause(); + assertThat(resumableFileUpload.totalParts()).hasValue(TOTAL_PARTS); + assertThat(resumableFileUpload.partSizeInBytes()).hasValue(PART_SIZE_IN_BYTES); + assertThat(resumableFileUpload.multipartUploadId()).hasValue(MULTIPART_UPLOAD_ID); + assertThat(resumableFileUpload.transferredParts()).hasValue(NUM_OF_PARTS_COMPLETED); + assertThat(resumableFileUpload.fileLength()).isEqualTo(file.length()); + assertThat(resumableFileUpload.uploadFileRequest()).isEqualTo(request); + assertThat(resumableFileUpload.fileLastModified()).isEqualTo(Instant.ofEpochMilli(file.lastModified())); + } + + @Test + void pause_singlePart_shouldReturnNullResumeToken() { + CompletableFuture future = new CompletableFuture<>(); + TransferProgress transferProgress = mock(TransferProgress.class); + + PauseObservable observable = new PauseObservable(); + observable.setPausableUpload(null); + + UploadFileRequest request = uploadFileRequest(); + DefaultFileUpload fileUpload = new DefaultFileUpload(future, transferProgress, observable, request); + + ResumableFileUpload resumableFileUpload = fileUpload.pause(); + assertThat(resumableFileUpload.totalParts()).isEmpty(); + assertThat(resumableFileUpload.partSizeInBytes()).isEmpty(); + assertThat(resumableFileUpload.multipartUploadId()).isEmpty(); + assertThat(resumableFileUpload.fileLength()).isEqualTo(file.length()); + assertThat(resumableFileUpload.uploadFileRequest()).isEqualTo(request); + assertThat(resumableFileUpload.fileLastModified()).isEqualTo(Instant.ofEpochMilli(file.lastModified())); + } + + private UploadFileRequest uploadFileRequest() { + return UploadFileRequest.builder() + .source(file) + .putObjectRequest(p -> p.key("test").bucket("bucket")) + .build(); } } \ No newline at end of file diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerUploadPauseAndResumeTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerUploadPauseAndResumeTest.java index 351fd03f7495..a79a8b4a5083 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerUploadPauseAndResumeTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/S3TransferManagerUploadPauseAndResumeTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SDK_HTTP_EXECUTION_ATTRIBUTES; import static software.amazon.awssdk.services.s3.internal.crt.S3InternalSdkHttpExecutionAttribute.CRT_PAUSE_RESUME_TOKEN; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.RESUME_TOKEN; import static software.amazon.awssdk.transfer.s3.SizeConstant.MB; import java.io.File; @@ -32,15 +33,20 @@ import java.time.Instant; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.http.SdkHttpExecutionAttributes; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; @@ -52,7 +58,9 @@ class S3TransferManagerUploadPauseAndResumeTest { private S3CrtAsyncClient mockS3Crt; - private S3TransferManager tm; + private S3AsyncClient mockS3; + private S3TransferManager tmCrt; + private S3TransferManager tmJava; private UploadDirectoryHelper uploadDirectoryHelper; private DownloadDirectoryHelper downloadDirectoryHelper; private TransferManagerConfiguration configuration; @@ -62,21 +70,52 @@ class S3TransferManagerUploadPauseAndResumeTest { public void methodSetup() throws IOException { file = RandomTempFile.createTempFile("test", UUID.randomUUID().toString()); Files.write(file.toPath(), RandomStringUtils.randomAlphanumeric(1000).getBytes(StandardCharsets.UTF_8)); - mockS3Crt = mock(S3CrtAsyncClient.class); uploadDirectoryHelper = mock(UploadDirectoryHelper.class); configuration = mock(TransferManagerConfiguration.class); downloadDirectoryHelper = mock(DownloadDirectoryHelper.class); - tm = new CrtS3TransferManager(configuration, mockS3Crt, false); + mockS3Crt = mock(S3CrtAsyncClient.class); + mockS3 = mock(S3AsyncClient.class); + tmCrt = new CrtS3TransferManager(configuration, mockS3Crt, false); + tmJava = new GenericS3TransferManager(mockS3, uploadDirectoryHelper, configuration, downloadDirectoryHelper); } @AfterEach public void methodTeardown() { file.delete(); - tm.close(); + tmCrt.close(); + tmJava.close(); + } + + enum TmType{ + JAVA, CRT + } + + private static Stream transferManagers() { + return Stream.of( + Arguments.of(TmType.JAVA), + Arguments.of(TmType.CRT) + ); } - @Test - void resumeUploadFile_noResumeToken_shouldUploadFromBeginning() { + private S3TransferManager configureTestBehavior(TmType tmType, PutObjectResponse response) { + if (tmType == TmType.JAVA) { + when(mockS3.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))) + .thenReturn(CompletableFuture.completedFuture(response)); + when(mockS3.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + return tmJava; + } else { + when(mockS3Crt.putObject(any(PutObjectRequest.class), any(Path.class))) + .thenReturn(CompletableFuture.completedFuture(response)); + when(mockS3Crt.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + return tmCrt; + } + } + + @ParameterizedTest + @MethodSource("transferManagers") + void resumeUploadFile_noResumeToken_shouldUploadFromBeginning(TmType tmType) { PutObjectRequest putObjectRequest = putObjectRequest(); PutObjectResponse response = PutObjectResponse.builder().build(); Instant fileLastModified = Instant.ofEpochMilli(file.lastModified()); @@ -87,9 +126,7 @@ void resumeUploadFile_noResumeToken_shouldUploadFromBeginning() { .source(file) .build(); - - when(mockS3Crt.putObject(any(PutObjectRequest.class), any(Path.class))) - .thenReturn(CompletableFuture.completedFuture(response)); + S3TransferManager tm = configureTestBehavior(tmType, response); CompletedFileUpload completedFileUpload = tm.resumeUploadFile(r -> r.fileLength(fileLength) .uploadFileRequest(uploadFileRequest) @@ -97,11 +134,17 @@ void resumeUploadFile_noResumeToken_shouldUploadFromBeginning() { .completionFuture() .join(); assertThat(completedFileUpload.response()).isEqualTo(response); - verifyActualPutObjectRequestNotResumed(); + + if (tmType == TmType.JAVA) { + verifyActualPutObjectRequestNotResumed_tmJava(); + } else { + verifyActualPutObjectRequestNotResumed_tmCrt(); + } } - @Test - void resumeUploadFile_fileModified_shouldAbortExistingAndUploadFromBeginning() { + @ParameterizedTest + @MethodSource("transferManagers") + void resumeUploadFile_fileModified_shouldAbortExistingAndUploadFromBeginning(TmType tmType) { PutObjectRequest putObjectRequest = putObjectRequest(); PutObjectResponse response = PutObjectResponse.builder().build(); Instant fileLastModified = Instant.ofEpochMilli(file.lastModified()); @@ -112,12 +155,7 @@ void resumeUploadFile_fileModified_shouldAbortExistingAndUploadFromBeginning() { .source(file) .build(); - - when(mockS3Crt.putObject(any(PutObjectRequest.class), any(Path.class))) - .thenReturn(CompletableFuture.completedFuture(response)); - - when(mockS3Crt.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + S3TransferManager tm = configureTestBehavior(tmType, response); String multipartId = "someId"; CompletedFileUpload completedFileUpload = tm.resumeUploadFile(r -> r.fileLength(fileLength + 10L) @@ -129,18 +167,29 @@ void resumeUploadFile_fileModified_shouldAbortExistingAndUploadFromBeginning() { .completionFuture() .join(); assertThat(completedFileUpload.response()).isEqualTo(response); - verifyActualPutObjectRequestNotResumed(); + + if (tmType == TmType.JAVA) { + verifyActualPutObjectRequestNotResumed_tmJava(); + } else { + verifyActualPutObjectRequestNotResumed_tmCrt(); + } ArgumentCaptor abortMultipartUploadRequestArgumentCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); - verify(mockS3Crt).abortMultipartUpload(abortMultipartUploadRequestArgumentCaptor.capture()); + + if (tmType == TmType.JAVA) { + verify(mockS3).abortMultipartUpload(abortMultipartUploadRequestArgumentCaptor.capture()); + } else { + verify(mockS3Crt).abortMultipartUpload(abortMultipartUploadRequestArgumentCaptor.capture()); + } AbortMultipartUploadRequest actualRequest = abortMultipartUploadRequestArgumentCaptor.getValue(); assertThat(actualRequest.uploadId()).isEqualTo(multipartId); } - @Test - void resumeUploadFile_hasValidResumeToken_shouldResumeUpload() { + @ParameterizedTest + @MethodSource("transferManagers") + void resumeUploadFile_hasValidResumeToken_shouldResumeUpload(TmType tmType) { PutObjectRequest putObjectRequest = putObjectRequest(); PutObjectResponse response = PutObjectResponse.builder().build(); Instant fileLastModified = Instant.ofEpochMilli(file.lastModified()); @@ -151,10 +200,7 @@ void resumeUploadFile_hasValidResumeToken_shouldResumeUpload() { .source(file) .build(); - - when(mockS3Crt.putObject(any(PutObjectRequest.class), any(Path.class))) - .thenReturn(CompletableFuture.completedFuture(response)); - + S3TransferManager tm = configureTestBehavior(tmType, response); String multipartId = "someId"; long totalParts = 10L; @@ -169,31 +215,65 @@ void resumeUploadFile_hasValidResumeToken_shouldResumeUpload() { .join(); assertThat(completedFileUpload.response()).isEqualTo(response); - ArgumentCaptor putObjectRequestArgumentCaptor = - ArgumentCaptor.forClass(PutObjectRequest.class); - verify(mockS3Crt).putObject(putObjectRequestArgumentCaptor.capture(), any(Path.class)); + if (tmType == TmType.JAVA) { + verifyActualPutObjectRequestResumedAndCorrectTokenReturned_tmJava(multipartId, partSizeInBytes, totalParts); + } else { + verifyActualPutObjectRequestResumedAndCorrectTokenReturned_tmCrt(multipartId, partSizeInBytes, totalParts); + } + } + + private void verifyActualPutObjectRequestNotResumed_tmCrt() { + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + + verify(mockS3Crt).putObject(putObjectRequestArgumentCaptor.capture(), any(Path.class)); PutObjectRequest actualRequest = putObjectRequestArgumentCaptor.getValue(); AwsRequestOverrideConfiguration awsRequestOverrideConfiguration = actualRequest.overrideConfiguration().get(); SdkHttpExecutionAttributes attribute = awsRequestOverrideConfiguration.executionAttributes().getAttribute(SDK_HTTP_EXECUTION_ATTRIBUTES); - assertThat(attribute.getAttribute(CRT_PAUSE_RESUME_TOKEN)).satisfies(token -> { - assertThat(token.getUploadId()).isEqualTo(multipartId); - assertThat(token.getPartSize()).isEqualTo(partSizeInBytes); - assertThat(token.getTotalNumParts()).isEqualTo(totalParts); - }); + assertThat(attribute.getAttribute(CRT_PAUSE_RESUME_TOKEN)).isNull(); + } + + private void verifyActualPutObjectRequestNotResumed_tmJava() { + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + + verify(mockS3).putObject(putObjectRequestArgumentCaptor.capture(), any(AsyncRequestBody.class)); + PutObjectRequest actualRequest = putObjectRequestArgumentCaptor.getValue(); + + assertThat(actualRequest.overrideConfiguration()).isEmpty(); + } + + private void verifyActualPutObjectRequestResumedAndCorrectTokenReturned_tmJava(String multipartId, long partSizeInBytes, + long totalParts) { + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + + verify(mockS3).putObject(putObjectRequestArgumentCaptor.capture(), any(AsyncRequestBody.class)); + PutObjectRequest actualRequest = putObjectRequestArgumentCaptor.getValue(); + AwsRequestOverrideConfiguration awsRequestOverrideConfiguration = actualRequest.overrideConfiguration().get(); + + assertThat(awsRequestOverrideConfiguration.executionAttributes().getAttribute(RESUME_TOKEN)).isNotNull(); + S3ResumeToken s3ResumeToken = awsRequestOverrideConfiguration.executionAttributes().getAttribute(RESUME_TOKEN); + + assertThat(s3ResumeToken.uploadId()).isEqualTo(multipartId); + assertThat(s3ResumeToken.partSize()).isEqualTo(partSizeInBytes); + assertThat(s3ResumeToken.totalNumParts()).isEqualTo(totalParts); } - private void verifyActualPutObjectRequestNotResumed() { - ArgumentCaptor putObjectRequestArgumentCaptor = - ArgumentCaptor.forClass(PutObjectRequest.class); + private void verifyActualPutObjectRequestResumedAndCorrectTokenReturned_tmCrt(String multipartId, long partSizeInBytes, + long totalParts) { + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(mockS3Crt).putObject(putObjectRequestArgumentCaptor.capture(), any(Path.class)); PutObjectRequest actualRequest = putObjectRequestArgumentCaptor.getValue(); AwsRequestOverrideConfiguration awsRequestOverrideConfiguration = actualRequest.overrideConfiguration().get(); SdkHttpExecutionAttributes attribute = awsRequestOverrideConfiguration.executionAttributes().getAttribute(SDK_HTTP_EXECUTION_ATTRIBUTES); - assertThat(attribute.getAttribute(CRT_PAUSE_RESUME_TOKEN)).isNull(); + assertThat(attribute.getAttribute(CRT_PAUSE_RESUME_TOKEN)).satisfies(token -> { + assertThat(token.getUploadId()).isEqualTo(multipartId); + assertThat(token.getPartSize()).isEqualTo(partSizeInBytes); + assertThat(token.getTotalNumParts()).isEqualTo(totalParts); + }); } private static PutObjectRequest putObjectRequest() { diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/UploadDirectoryHelperTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/UploadDirectoryHelperTest.java index aba7cc86ae0d..9e975f09a357 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/UploadDirectoryHelperTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/UploadDirectoryHelperTest.java @@ -48,7 +48,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.services.s3.internal.crt.S3MetaRequestPauseObservable; +import software.amazon.awssdk.services.s3.multipart.PauseObservable; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.testutils.FileUtils; @@ -443,6 +443,7 @@ private DefaultFileUpload completedUpload() { new DefaultTransferProgress(DefaultTransferProgressSnapshot.builder() .transferredBytes(0L) .build()), + new PauseObservable(), UploadFileRequest.builder() .source(Paths.get(".")).putObjectRequest(b -> b.bucket("bucket").key("key")) .build()); @@ -453,6 +454,7 @@ private FileUpload newUpload(CompletableFuture future) { new DefaultTransferProgress(DefaultTransferProgressSnapshot.builder() .transferredBytes(0L) .build()), + new PauseObservable(), UploadFileRequest.builder() .putObjectRequest(p -> p.key("key").bucket("bucket")).source(Paths.get( "test.txt")) diff --git a/services/s3/pom.xml b/services/s3/pom.xml index 4dac961a5bfa..86a0df660121 100644 --- a/services/s3/pom.xml +++ b/services/s3/pom.xml @@ -204,6 +204,11 @@ org.mockito mockito-junit-jupiter + + org.mockito + mockito-inline + test + net.bytebuddy byte-buddy diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CancelledSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CancelledSubscriber.java new file mode 100644 index 000000000000..a9a010d48983 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/CancelledSubscriber.java @@ -0,0 +1,48 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; + +@SdkInternalApi +public final class CancelledSubscriber implements Subscriber { + + @Override + public void onSubscribe(Subscription subscription) { + if (subscription == null) { + throw new NullPointerException("Null subscription"); + } else { + subscription.cancel(); + } + } + + @Override + public void onNext(T t) { + } + + @Override + public void onError(Throwable error) { + if (error == null) { + throw new NullPointerException("Null error published"); + } + } + + @Override + public void onComplete() { + } +} \ No newline at end of file diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java new file mode 100644 index 000000000000..59be53e13642 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java @@ -0,0 +1,214 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.NumericUtils; +import software.amazon.awssdk.utils.Pair; + +@SdkInternalApi +public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { + + private static final Logger log = Logger.loggerFor(KnownContentLengthAsyncRequestBodySubscriber.class); + + /** + * The number of AsyncRequestBody has been received but yet to be processed + */ + private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); + private final AtomicInteger partNumber = new AtomicInteger(1); + private final MultipartUploadHelper multipartUploadHelper; + private final long partSize; + private final int partCount; + private final int numExistingParts; + private final String uploadId; + private final Collection> futures = new ConcurrentLinkedQueue<>(); + private final PutObjectRequest putObjectRequest; + private final CompletableFuture returnFuture; + private final Map completedParts; + private final Map existingParts; + private Subscription subscription; + private volatile boolean isDone; + private volatile boolean isPaused; + private volatile CompletableFuture completeMpuFuture; + + KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, + CompletableFuture returnFuture, + MultipartUploadHelper multipartUploadHelper) { + this.partSize = mpuRequestContext.partSize(); + this.partCount = determinePartCount(mpuRequestContext.contentLength(), partSize); + this.putObjectRequest = mpuRequestContext.request().left(); + this.returnFuture = returnFuture; + this.uploadId = mpuRequestContext.uploadId(); + this.existingParts = mpuRequestContext.existingParts(); + this.numExistingParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted()); + this.completedParts = new ConcurrentHashMap<>(); + this.multipartUploadHelper = multipartUploadHelper; + } + + private int determinePartCount(long contentLength, long partSize) { + return (int) Math.ceil(contentLength / (double) partSize); + } + + public S3ResumeToken pause() { + isPaused = true; + + if (completeMpuFuture != null && completeMpuFuture.isDone()) { + return null; + } + + if (completeMpuFuture != null && !completeMpuFuture.isDone()) { + completeMpuFuture.cancel(true); + } + + long numPartsCompleted = 0; + for (CompletableFuture cf : futures) { + if (!cf.isDone()) { + cf.cancel(true); + } else { + numPartsCompleted++; + } + } + + return S3ResumeToken.builder() + .uploadId(uploadId) + .partSize(partSize) + .totalNumParts((long) partCount) + .numPartsCompleted(numPartsCompleted + numExistingParts) + .build(); + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); + subscription.cancel(); + return; + } + this.subscription = s; + s.request(1); + returnFuture.whenComplete((r, t) -> { + if (t != null) { + s.cancel(); + if (shouldFailRequest()) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } + }); + } + + @Override + public void onNext(AsyncRequestBody asyncRequestBody) { + if (isPaused) { + return; + } + + if (existingParts.containsKey(partNumber.get())) { + partNumber.getAndIncrement(); + asyncRequestBody.subscribe(new CancelledSubscriber<>()); + subscription.request(1); + return; + } + + asyncRequestBodyInFlight.incrementAndGet(); + UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, + partNumber.getAndIncrement(), + uploadId); + + Consumer completedPartConsumer = + completedPart -> completedParts.put(completedPart.partNumber(), completedPart); + multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, + Pair.of(uploadRequest, asyncRequestBody)) + .whenComplete((r, t) -> { + if (t != null) { + if (shouldFailRequest()) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, + putObjectRequest); + } + } else { + completeMultipartUploadIfFinished(asyncRequestBodyInFlight.decrementAndGet()); + } + }); + subscription.request(1); + } + + private boolean shouldFailRequest() { + return failureActionInitiated.compareAndSet(false, true) && !isPaused; + } + + @Override + public void onError(Throwable t) { + log.debug(() -> "Received onError ", t); + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } + + @Override + public void onComplete() { + log.debug(() -> "Received onComplete()"); + isDone = true; + if (!isPaused) { + completeMultipartUploadIfFinished(asyncRequestBodyInFlight.get()); + } + } + + private void completeMultipartUploadIfFinished(int requestsInFlight) { + if (isDone && requestsInFlight == 0) { + CompletedPart[] parts; + if (existingParts.isEmpty()) { + parts = completedParts.values().toArray(new CompletedPart[0]); + } else if (!completedParts.isEmpty()) { + // List of CompletedParts needs to be in ascending order + parts = mergeCompletedParts(); + } else { + parts = existingParts.values().toArray(new CompletedPart[0]); + } + completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, + putObjectRequest); + } + } + + private CompletedPart[] mergeCompletedParts() { + CompletedPart[] merged = new CompletedPart[partCount]; + int currPart = 1; + while (currPart < partCount + 1) { + CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) : + completedParts.get(currPart); + merged[currPart - 1] = completedPart; + currPart++; + } + return merged; + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java new file mode 100644 index 000000000000..6c4b978e4183 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContext.java @@ -0,0 +1,145 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.utils.Pair; + +@SdkInternalApi +public class MpuRequestContext { + + private final Pair request; + private final Long contentLength; + private final Long partSize; + private final Long numPartsCompleted; + private final String uploadId; + private final Map existingParts; + + protected MpuRequestContext(Builder builder) { + this.request = builder.request; + this.contentLength = builder.contentLength; + this.partSize = builder.partSize; + this.uploadId = builder.uploadId; + this.existingParts = builder.existingParts; + this.numPartsCompleted = builder.numPartsCompleted; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MpuRequestContext that = (MpuRequestContext) o; + + return Objects.equals(request, that.request) && Objects.equals(contentLength, that.contentLength) + && Objects.equals(partSize, that.partSize) && Objects.equals(numPartsCompleted, that.numPartsCompleted) + && Objects.equals(uploadId, that.uploadId) && Objects.equals(existingParts, that.existingParts); + } + + @Override + public int hashCode() { + int result = request != null ? request.hashCode() : 0; + result = 31 * result + (uploadId != null ? uploadId.hashCode() : 0); + result = 31 * result + (existingParts != null ? existingParts.hashCode() : 0); + result = 31 * result + (contentLength != null ? contentLength.hashCode() : 0); + result = 31 * result + (partSize != null ? partSize.hashCode() : 0); + result = 31 * result + (numPartsCompleted != null ? numPartsCompleted.hashCode() : 0); + return result; + } + + public Pair request() { + return request; + } + + public Long contentLength() { + return contentLength; + } + + public Long partSize() { + return partSize; + } + + public Long numPartsCompleted() { + return numPartsCompleted; + } + + public String uploadId() { + return uploadId; + } + + public Map existingParts() { + return existingParts != null ? Collections.unmodifiableMap(existingParts) : null; + } + + public static final class Builder { + private Pair request; + private Long contentLength; + private Long partSize; + private Long numPartsCompleted; + private String uploadId; + private Map existingParts; + + private Builder() { + } + + public Builder request(Pair request) { + this.request = request; + return this; + } + + public Builder contentLength(Long contentLength) { + this.contentLength = contentLength; + return this; + } + + public Builder partSize(Long partSize) { + this.partSize = partSize; + return this; + } + + public Builder numPartsCompleted(Long numPartsCompleted) { + this.numPartsCompleted = numPartsCompleted; + return this; + } + + public Builder uploadId(String uploadId) { + this.uploadId = uploadId; + return this; + } + + public Builder existingParts(Map existingParts) { + this.existingParts = existingParts; + return this; + } + + public MpuRequestContext build() { + return new MpuRequestContext(this); + } + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index 9754d284f5b9..c5bb7fe286cc 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -24,6 +24,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; @@ -74,20 +75,20 @@ CompletableFuture createMultipartUpload(PutObject return createMultipartUploadFuture; } - void completeMultipartUpload(CompletableFuture returnFuture, + CompletableFuture completeMultipartUpload(CompletableFuture returnFuture, String uploadId, CompletedPart[] completedParts, PutObjectRequest putObjectRequest) { - genericMultipartHelper.completeMultipartUpload(putObjectRequest, - uploadId, - completedParts) - .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, - uploadId)) - .exceptionally(throwable -> { - genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", - throwable); - return null; - }); + CompletableFuture future = + genericMultipartHelper.completeMultipartUpload(putObjectRequest, uploadId, completedParts); + + future.handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, uploadId)) + .exceptionally(throwable -> { + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", throwable); + return null; + }); + + return future; } CompletableFuture sendIndividualUploadPartRequest(String uploadId, diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PausableUpload.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PausableUpload.java new file mode 100644 index 000000000000..2e0d1885d432 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/PausableUpload.java @@ -0,0 +1,27 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; + +@SdkProtectedApi +public interface PausableUpload { + + default S3ResumeToken pause() { + throw new UnsupportedOperationException(); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java index b29f176e6fb5..bff5d389e1e9 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtils.java @@ -34,6 +34,8 @@ import software.amazon.awssdk.services.s3.model.CopyPartResult; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.ListPartsRequest; +import software.amazon.awssdk.services.s3.model.Part; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; @@ -61,13 +63,6 @@ public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRe return builder.uploadId(uploadId).partNumber(partNumber).build(); } - public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObjectRequest putObjectRequest) { - - CreateMultipartUploadRequest.Builder builder = CreateMultipartUploadRequest.builder(); - setSdkFields(builder, putObjectRequest); - return builder.build(); - } - public static CompleteMultipartUploadRequest toCompleteMultipartUploadRequest(PutObjectRequest putObjectRequest, String uploadId, CompletedPart[] parts) { CompleteMultipartUploadRequest.Builder builder = CompleteMultipartUploadRequest.builder(); @@ -75,6 +70,13 @@ public static CompleteMultipartUploadRequest toCompleteMultipartUploadRequest(Pu return builder.uploadId(uploadId).multipartUpload(c -> c.parts(parts)).build(); } + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObjectRequest putObjectRequest) { + + CreateMultipartUploadRequest.Builder builder = CreateMultipartUploadRequest.builder(); + setSdkFields(builder, putObjectRequest); + return builder.build(); + } + public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { // We can't set SdkFields directly because the fields in CopyObjectRequest do not match 100% with the ones in @@ -107,6 +109,18 @@ public static CompletedPart toCompletedPart(UploadPartResponse partResponse, int return builder.partNumber(partNumber).build(); } + public static CompletedPart toCompletedPart(Part part) { + CompletedPart.Builder builder = CompletedPart.builder(); + setSdkFields(builder, part); + return builder.build(); + } + + public static ListPartsRequest toListPartsRequest(String uploadId, PutObjectRequest putObjectRequest) { + ListPartsRequest.Builder builder = ListPartsRequest.builder(); + setSdkFields(builder, putObjectRequest); + return builder.uploadId(uploadId).build(); + } + private static void setSdkFields(SdkPojo targetBuilder, SdkPojo sourceObject) { setSdkFields(targetBuilder, sourceObject, new HashSet<>()); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index 46caefca8d61..9cb1aa62a100 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -15,25 +15,25 @@ package software.amazon.awssdk.services.s3.internal.multipart; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.PAUSE_OBSERVABLE; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.RESUME_TOKEN; -import java.util.Collection; +import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.function.Consumer; -import java.util.stream.IntStream; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; +import java.util.concurrent.ConcurrentHashMap; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.ListPartsRequest; +import software.amazon.awssdk.services.s3.model.Part; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; +import software.amazon.awssdk.services.s3.paginators.ListPartsPublisher; +import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Pair; @@ -47,7 +47,6 @@ public final class UploadWithKnownContentLengthHelper { private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; private final GenericMultipartHelper genericMultipartHelper; - private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; private final MultipartUploadHelper multipartUploadHelper; @@ -71,7 +70,6 @@ public CompletableFuture uploadObject(PutObjectRequest putObj AsyncRequestBody asyncRequestBody, long contentLength) { CompletableFuture returnFuture = new CompletableFuture<>(); - try { if (contentLength > multipartUploadThresholdInBytes && contentLength > partSizeInBytes) { log.debug(() -> "Starting the upload as multipart upload request"); @@ -80,7 +78,6 @@ public CompletableFuture uploadObject(PutObjectRequest putObj log.debug(() -> "Starting the upload as a single upload part request"); multipartUploadHelper.uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); } - } catch (Throwable throwable) { returnFuture.completeExceptionally(throwable); } @@ -90,7 +87,21 @@ public CompletableFuture uploadObject(PutObjectRequest putObj private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, CompletableFuture returnFuture) { + S3ResumeToken resumeToken = putObjectRequest.overrideConfiguration() + .map(c -> c.executionAttributes() + .getAttribute(RESUME_TOKEN)).orElse(null); + + if (resumeToken == null) { + initiateNewUpload(putObjectRequest, contentLength, asyncRequestBody, returnFuture); + } else { + ResumeRequestContext resumeRequestContext = new ResumeRequestContext(resumeToken, putObjectRequest, contentLength, + asyncRequestBody, returnFuture); + resumePausedUpload(resumeRequestContext); + } + } + private void initiateNewUpload(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { CompletableFuture createMultipartUploadFuture = multipartUploadHelper.createMultipartUpload(putObjectRequest, returnFuture); @@ -99,160 +110,133 @@ private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); } else { log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); - doUploadInParts(Pair.of(putObjectRequest, asyncRequestBody), contentLength, returnFuture, - createMultipartUploadResponse.uploadId()); + uploadFromBeginning(Pair.of(putObjectRequest, asyncRequestBody), contentLength, returnFuture, + createMultipartUploadResponse.uploadId()); } }); } - private void doUploadInParts(Pair request, - long contentLength, - CompletableFuture returnFuture, - String uploadId) { + private void uploadFromBeginning(Pair request, long contentLength, + CompletableFuture returnFuture, String uploadId) { - long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); - int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); - if (optimalPartSize > partSizeInBytes) { - log.debug(() -> String.format("Configured partSize is %d, but using %d to prevent reaching maximum number of parts " - + "allowed", partSizeInBytes, optimalPartSize)); + long numPartsCompleted = 0; + long partSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); + int partCount = genericMultipartHelper.determinePartCount(contentLength, partSize); + + if (partSize > partSizeInBytes) { + log.debug(() -> String.format("Configured partSize is %d, but using %d to prevent reaching maximum number of " + + "parts allowed", partSizeInBytes, partSize)); } log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, - optimalPartSize)); + partSize)); + + MpuRequestContext mpuRequestContext = MpuRequestContext.builder() + .request(request) + .contentLength(contentLength) + .partSize(partSize) + .uploadId(uploadId) + .existingParts(new ConcurrentHashMap<>()) + .numPartsCompleted(numPartsCompleted) + .build(); + + splitAndSubscribe(mpuRequestContext, returnFuture); + } - MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); + private void resumePausedUpload(ResumeRequestContext resumeContext) { + S3ResumeToken resumeToken = resumeContext.resumeToken; + String uploadId = resumeToken.uploadId(); + PutObjectRequest putObjectRequest = resumeContext.putObjectRequest; + Map existingParts = new ConcurrentHashMap<>(); + CompletableFuture listPartsFuture = identifyExistingPartsForResume(uploadId, putObjectRequest, existingParts); - request.right() - .split(b -> b.chunkSizeInBytes(mpuRequestContext.partSize) - .bufferSizeInBytes(maxMemoryUsageInBytes)) - .subscribe(new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, - returnFuture)); - } + int remainingParts = (int) (resumeToken.totalNumParts() - resumeToken.numPartsCompleted()); + log.debug(() -> String.format("Resuming a paused multipart upload, uploadId: %s, completedPartCount: %d, " + + "remainingPartCount: %d, partSize: %d", + uploadId, resumeToken.numPartsCompleted(), remainingParts, resumeToken.partSize())); - private static final class MpuRequestContext { - private final Pair request; - private final long contentLength; - private final long partSize; + CompletableFutureUtils.forwardExceptionTo(resumeContext.returnFuture, listPartsFuture); - private final String uploadId; + listPartsFuture.whenComplete((r, t) -> { + if (t != null) { + genericMultipartHelper.handleException(resumeContext.returnFuture, + () -> "Failed to resume because listParts failed", t); + return; + } - private MpuRequestContext(Pair request, - long contentLength, - long partSize, - String uploadId) { - this.request = request; - this.contentLength = contentLength; - this.partSize = partSize; - this.uploadId = uploadId; - } + Pair request = Pair.of(putObjectRequest, resumeContext.asyncRequestBody); + MpuRequestContext mpuRequestContext = MpuRequestContext.builder() + .request(request) + .contentLength(resumeContext.contentLength) + .partSize(resumeToken.partSize()) + .uploadId(uploadId) + .existingParts(existingParts) + .numPartsCompleted(resumeToken.numPartsCompleted()) + .build(); + + splitAndSubscribe(mpuRequestContext, resumeContext.returnFuture); + }); } - private class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { - - /** - * The number of AsyncRequestBody has been received but yet to be processed - */ - private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + private void splitAndSubscribe(MpuRequestContext mpuRequestContext, CompletableFuture returnFuture) { + KnownContentLengthAsyncRequestBodySubscriber subscriber = + new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper); - /** - * Indicates whether CompleteMultipart has been initiated or not. - */ - private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); + attachSubscriberToObservable(subscriber, mpuRequestContext.request().left()); - private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); + mpuRequestContext.request().right() + .split(b -> b.chunkSizeInBytes(mpuRequestContext.partSize()) + .bufferSizeInBytes(maxMemoryUsageInBytes)) + .subscribe(subscriber); + } - private final AtomicInteger partNumber = new AtomicInteger(1); + private CompletableFuture identifyExistingPartsForResume(String uploadId, PutObjectRequest putObjectRequest, + Map existingParts) { + ListPartsRequest request = SdkPojoConversionUtils.toListPartsRequest(uploadId, putObjectRequest); + ListPartsPublisher listPartsPublisher = s3AsyncClient.listPartsPaginator(request); + SdkPublisher partsPublisher = listPartsPublisher.parts(); + return partsPublisher.subscribe(part -> + existingParts.put(part.partNumber(), SdkPojoConversionUtils.toCompletedPart(part))); + } - private final AtomicReferenceArray completedParts; - private final String uploadId; - private final Collection> futures = new ConcurrentLinkedQueue<>(); + private void attachSubscriberToObservable(KnownContentLengthAsyncRequestBodySubscriber subscriber, + PutObjectRequest putObjectRequest) { + // observable will be present if TransferManager is used + putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes().getAttribute(PAUSE_OBSERVABLE)) + .ifPresent(p -> p.setPausableUpload(new DefaultPausableUpload(subscriber))); + } + private static final class ResumeRequestContext { + private final S3ResumeToken resumeToken; private final PutObjectRequest putObjectRequest; + private final long contentLength; + private final AsyncRequestBody asyncRequestBody; private final CompletableFuture returnFuture; - private Subscription subscription; - - private volatile boolean isDone; - KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, - CompletableFuture returnFuture) { - long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(mpuRequestContext.contentLength, - partSizeInBytes); - int partCount = genericMultipartHelper.determinePartCount(mpuRequestContext.contentLength, optimalPartSize); - this.putObjectRequest = mpuRequestContext.request.left(); + private ResumeRequestContext(S3ResumeToken resumeToken, + PutObjectRequest putObjectRequest, + long contentLength, + AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + this.resumeToken = resumeToken; + this.putObjectRequest = putObjectRequest; + this.contentLength = contentLength; + this.asyncRequestBody = asyncRequestBody; this.returnFuture = returnFuture; - this.completedParts = new AtomicReferenceArray<>(partCount); - this.uploadId = mpuRequestContext.uploadId; } + } - @Override - public void onSubscribe(Subscription s) { - if (this.subscription != null) { - log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); - subscription.cancel(); - return; - } - this.subscription = s; - s.request(1); - returnFuture.whenComplete((r, t) -> { - if (t != null) { - s.cancel(); - if (failureActionInitiated.compareAndSet(false, true)) { - multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); - } - } - }); - } + private static final class DefaultPausableUpload implements PausableUpload { - @Override - public void onNext(AsyncRequestBody asyncRequestBody) { - log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); - asyncRequestBodyInFlight.incrementAndGet(); - UploadPartRequest uploadRequest = - SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, - partNumber.getAndIncrement(), - uploadId); - - Consumer completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1, - completedPart); - multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, - Pair.of(uploadRequest, asyncRequestBody)) - .whenComplete((r, t) -> { - if (t != null) { - if (failureActionInitiated.compareAndSet(false, true)) { - multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, - putObjectRequest); - } - } else { - completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); - } - }); - subscription.request(1); - } + private KnownContentLengthAsyncRequestBodySubscriber subscriber; - @Override - public void onError(Throwable t) { - log.debug(() -> "Received onError ", t); - if (failureActionInitiated.compareAndSet(false, true)) { - multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); - } + private DefaultPausableUpload(KnownContentLengthAsyncRequestBodySubscriber subscriber) { + this.subscriber = subscriber; } @Override - public void onComplete() { - log.debug(() -> "Received onComplete()"); - isDone = true; - completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); + public S3ResumeToken pause() { + return subscriber.pause(); } - - private void completeMultipartUploadIfFinish(int requestsInFlight) { - if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { - CompletedPart[] parts = - IntStream.range(0, completedParts.length()) - .mapToObj(completedParts::get) - .toArray(CompletedPart[]::new); - multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest); - } - } - } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/PauseObservable.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/PauseObservable.java new file mode 100644 index 000000000000..49886c7beeac --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/PauseObservable.java @@ -0,0 +1,41 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.services.s3.internal.multipart.PausableUpload; + +@SdkProtectedApi +public class PauseObservable { + + private volatile PausableUpload pausableUpload; + + public void setPausableUpload(PausableUpload pausableUpload) { + this.pausableUpload = pausableUpload; + } + + public S3ResumeToken pause() { + // single part upload or TM is not used + if (pausableUpload == null) { + return null; + } + return pausableUpload.pause(); + } + + public PausableUpload pausableUpload() { + return pausableUpload; + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3PauseResumeExecutionAttribute.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3PauseResumeExecutionAttribute.java new file mode 100644 index 000000000000..3aae35725557 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3PauseResumeExecutionAttribute.java @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.interceptor.ExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; + +@SdkProtectedApi +public final class S3PauseResumeExecutionAttribute extends SdkExecutionAttribute { + public static final ExecutionAttribute RESUME_TOKEN = new ExecutionAttribute<>("ResumeToken"); + public static final ExecutionAttribute PAUSE_OBSERVABLE = new ExecutionAttribute<>("PauseObservable"); +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3ResumeToken.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3ResumeToken.java new file mode 100644 index 000000000000..2ec2223fcec5 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/S3ResumeToken.java @@ -0,0 +1,108 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import java.util.Objects; +import software.amazon.awssdk.annotations.SdkProtectedApi; + +@SdkProtectedApi +public class S3ResumeToken { + + private final String uploadId; + private final Long partSize; + private final Long totalNumParts; + private final Long numPartsCompleted; + + public S3ResumeToken(Builder builder) { + this.uploadId = builder.uploadId; + this.partSize = builder.partSize; + this.totalNumParts = builder.totalNumParts; + this.numPartsCompleted = builder.numPartsCompleted; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + S3ResumeToken that = (S3ResumeToken) o; + + return partSize == that.partSize && totalNumParts == that.totalNumParts && numPartsCompleted == that.numPartsCompleted + && Objects.equals(uploadId, that.uploadId); + } + + @Override + public int hashCode() { + return Objects.hashCode(uploadId); + } + + public String uploadId() { + return uploadId; + } + + public Long partSize() { + return partSize; + } + + public Long totalNumParts() { + return totalNumParts; + } + + public Long numPartsCompleted() { + return numPartsCompleted; + } + + public static final class Builder { + private String uploadId; + private Long partSize; + private Long totalNumParts; + private Long numPartsCompleted; + + private Builder() { + } + + public Builder uploadId(String uploadId) { + this.uploadId = uploadId; + return this; + } + + public Builder partSize(Long partSize) { + this.partSize = partSize; + return this; + } + + public Builder totalNumParts(Long totalNumParts) { + this.totalNumParts = totalNumParts; + return this; + } + + public Builder numPartsCompleted(Long numPartsCompleted) { + this.numPartsCompleted = numPartsCompleted; + return this; + } + + public S3ResumeToken build() { + return new S3ResumeToken(this); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java new file mode 100644 index 000000000000..0ffbab391b3c --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java @@ -0,0 +1,142 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.utils.Pair; + +public class KnownContentLengthAsyncRequestBodySubscriberTest { + + // Should contain four parts: [8KB, 8KB, 8KB, 1KB] + private static final long MPU_CONTENT_SIZE = 25 * 1024; + private static final long PART_SIZE = 8 * 1024; + private static final int TOTAL_NUM_PARTS = 4; + private static final String UPLOAD_ID = "1234"; + private static RandomTempFile testFile; + private AsyncRequestBody asyncRequestBody; + private PutObjectRequest putObjectRequest; + private S3AsyncClient s3AsyncClient; + private MultipartUploadHelper multipartUploadHelper; + + @BeforeAll + public static void beforeAll() throws IOException { + testFile = new RandomTempFile("testfile.dat", MPU_CONTENT_SIZE); + } + + @AfterAll + public static void afterAll() { + testFile.delete(); + } + + @BeforeEach + public void beforeEach() { + s3AsyncClient = mock(S3AsyncClient.class); + multipartUploadHelper = mock(MultipartUploadHelper.class); + asyncRequestBody = AsyncRequestBody.fromFile(testFile); + putObjectRequest = PutObjectRequest.builder().bucket("bucket").key("key").build(); + } + + @Test + void pause_withOngoingCompleteMpuFuture_shouldReturnTokenAndCancelFuture() { + CompletableFuture completeMpuFuture = new CompletableFuture<>(); + int numExistingParts = 2; + S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture); + + verifyResumeToken(resumeToken, numExistingParts); + assertThat(completeMpuFuture).isCancelled(); + } + + @Test + void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { + CompletableFuture completeMpuFuture = + CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder().build()); + int numExistingParts = 2; + S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture); + + assertThat(resumeToken).isNull(); + } + + @Test + void pause_withUninitiatedCompleteMpuFuture_shouldReturnToken() { + CompletableFuture completeMpuFuture = null; + int numExistingParts = 2; + S3ResumeToken resumeToken = configureSubscriberAndPause(numExistingParts, completeMpuFuture); + + verifyResumeToken(resumeToken, numExistingParts); + } + + private S3ResumeToken configureSubscriberAndPause(int numExistingParts, + CompletableFuture completeMpuFuture) { + Map existingParts = existingParts(numExistingParts); + KnownContentLengthAsyncRequestBodySubscriber subscriber = subscriber(putObjectRequest, asyncRequestBody, existingParts); + + when(multipartUploadHelper.completeMultipartUpload(any(CompletableFuture.class), any(String.class), + any(CompletedPart[].class), any(PutObjectRequest.class))).thenReturn(completeMpuFuture); + subscriber.onComplete(); + return subscriber.pause(); + } + + private KnownContentLengthAsyncRequestBodySubscriber subscriber(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody, + Map existingParts) { + + MpuRequestContext mpuRequestContext = MpuRequestContext.builder() + .request(Pair.of(putObjectRequest, asyncRequestBody)) + .contentLength(MPU_CONTENT_SIZE) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .existingParts(existingParts) + .numPartsCompleted((long) existingParts.size()) + .build(); + + return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, new CompletableFuture<>(), multipartUploadHelper); + } + + private Map existingParts(int numExistingParts) { + Map existingParts = new ConcurrentHashMap<>(); + for (int i = 1; i <= numExistingParts; i++) { + existingParts.put(i, CompletedPart.builder().partNumber(i).build()); + } + return existingParts; + } + + private void verifyResumeToken(S3ResumeToken s3ResumeToken, int numExistingParts) { + assertThat(s3ResumeToken).isNotNull(); + assertThat(s3ResumeToken.uploadId()).isEqualTo(UPLOAD_ID); + assertThat(s3ResumeToken.partSize()).isEqualTo(PART_SIZE); + assertThat(s3ResumeToken.totalNumParts()).isEqualTo(TOTAL_NUM_PARTS); + assertThat(s3ResumeToken.numPartsCompleted()).isEqualTo(numExistingParts); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java new file mode 100644 index 000000000000..c858e7e8e9ec --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuRequestContextTest.java @@ -0,0 +1,73 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.utils.Pair; + +public class MpuRequestContextTest { + + private static final Pair REQUEST = Pair.of(PutObjectRequest.builder().build(), AsyncRequestBody.empty()); + private static final long CONTENT_LENGTH = 999; + private static final long PART_SIZE = 111; + private static final long NUM_PARTS_COMPLETED = 3; + private static final String UPLOAD_ID = "55555"; + private static final Map EXISTING_PARTS = new ConcurrentHashMap<>(); + + @Test + public void mpuRequestContext_withValues_buildsCorrectly() { + MpuRequestContext mpuRequestContext = MpuRequestContext.builder() + .request(REQUEST) + .contentLength(CONTENT_LENGTH) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .existingParts(EXISTING_PARTS) + .numPartsCompleted(NUM_PARTS_COMPLETED) + .build(); + + assertThat(mpuRequestContext.request()).isEqualTo(REQUEST); + assertThat(mpuRequestContext.contentLength()).isEqualTo(CONTENT_LENGTH); + assertThat(mpuRequestContext.partSize()).isEqualTo(PART_SIZE); + assertThat(mpuRequestContext.uploadId()).isEqualTo(UPLOAD_ID); + assertThat(mpuRequestContext.existingParts()).isEqualTo(EXISTING_PARTS); + assertThat(mpuRequestContext.numPartsCompleted()).isEqualTo(NUM_PARTS_COMPLETED); + } + + @Test + public void mpuRequestContext_default_buildsCorrectly() { + MpuRequestContext mpuRequestContext = MpuRequestContext.builder().build(); + + assertThat(mpuRequestContext.request()).isNull(); + assertThat(mpuRequestContext.contentLength()).isNull(); + assertThat(mpuRequestContext.partSize()).isNull(); + assertThat(mpuRequestContext.uploadId()).isNull(); + assertThat(mpuRequestContext.existingParts()).isNull(); + assertThat(mpuRequestContext.numPartsCompleted()).isNull(); + } + + @Test + void testEqualsAndHashCodeContract() { + EqualsVerifier.forClass(MpuRequestContext.class); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java index 435d5b406189..23fe07ab2743 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java @@ -19,6 +19,9 @@ import static org.mockito.Mockito.when; import java.util.concurrent.CompletableFuture; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; @@ -26,6 +29,9 @@ import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; public final class MpuTestUtils { @@ -62,4 +68,33 @@ public static void stubSuccessfulCompleteMultipartCall(String bucket, String key when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) .thenReturn(completeMultipartUploadFuture); } + + public static void stubSuccessfulUploadPartCalls(S3AsyncClient s3AsyncClient) { + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenAnswer(new Answer>() { + + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) { + AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); + // Draining the request body + AsyncRequestBody.subscribe(b -> {}); + + return CompletableFuture.completedFuture(UploadPartResponse.builder() + .build()); + } + }); + } + + public static S3ResumeToken s3ResumeToken(long numPartsCompleted, long partSize, long contentLength, String uploadId) { + return S3ResumeToken.builder() + .uploadId(uploadId) + .partSize(partSize) + .numPartsCompleted(numPartsCompleted) + .totalNumParts(determinePartCount(contentLength, partSize)) + .build(); + } + + public static long determinePartCount(long contentLength, long partSize) { + return (long) Math.ceil(contentLength / (double) partSize); + } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java index 0c7e79f2d2c5..0f3ab5b5589f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/SdkPojoConversionUtilsTest.java @@ -45,6 +45,7 @@ import software.amazon.awssdk.services.s3.model.CopyPartResult; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.ListPartsRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3ResponseMetadata; @@ -204,7 +205,7 @@ void toCompletedPart_putObject_shouldCopyProperties() { @Test void toCompleteMultipartUploadRequest_putObject_shouldCopyProperties() { PutObjectRequest randomObject = randomPutObjectRequest(); - CompletedPart parts[] = new CompletedPart[1]; + CompletedPart[] parts = new CompletedPart[1]; CompletedPart completedPart = CompletedPart.builder().partNumber(1).build(); parts[0] = completedPart; CompleteMultipartUploadRequest convertedObject = @@ -218,6 +219,18 @@ void toCompleteMultipartUploadRequest_putObject_shouldCopyProperties() { assertThat(convertedObject.multipartUpload().parts()).contains(completedPart); } + @Test + void toListPartsRequest_putObject_shouldCopyProperties() { + PutObjectRequest randomObject = randomPutObjectRequest(); + ListPartsRequest convertedObject = SdkPojoConversionUtils.toListPartsRequest("uploadId", randomObject); + Set fieldsToIgnore = new HashSet<>(); + + verifyFieldsAreCopied(randomObject, convertedObject, fieldsToIgnore, + PutObjectRequest.builder().sdkFields(), + ListPartsRequest.builder().sdkFields()); + assertThat(convertedObject.uploadId()).isEqualTo("uploadId"); + } + private static void verifyFieldsAreCopied(SdkPojo requestConvertedFrom, SdkPojo requestConvertedTo, Set fieldsToIgnore, diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java index c18d4c993114..d17ab358c527 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelperTest.java @@ -19,11 +19,17 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.s3ResumeToken; import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall; +import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall; +import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.PAUSE_OBSERVABLE; +import static software.amazon.awssdk.services.s3.multipart.S3PauseResumeExecutionAttribute.RESUME_TOKEN; import java.io.IOException; import java.nio.ByteBuffer; @@ -32,13 +38,14 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; @@ -50,19 +57,26 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.ListPartsRequest; +import software.amazon.awssdk.services.s3.model.Part; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; +import software.amazon.awssdk.services.s3.multipart.PauseObservable; +import software.amazon.awssdk.services.s3.multipart.S3ResumeToken; +import software.amazon.awssdk.services.s3.paginators.ListPartsPublisher; import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.utils.CompletableFutureUtils; @@ -140,8 +154,8 @@ void uploadObject_unKnownContentLengthDoesNotExceedPartSize_shouldUploadInOneChu void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); - MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); - stubSuccessfulUploadPartCalls(); + stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(s3AsyncClient); stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient); uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join(); @@ -178,7 +192,7 @@ void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU(AsyncRequ void mpu_onePartFailed_shouldFailOtherPartsAndAbort(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(MPU_CONTENT_SIZE); - MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); CompletableFuture ongoingRequest = new CompletableFuture<>(); SdkClientException exception = SdkClientException.create("request failed"); @@ -239,12 +253,12 @@ void upload_knownContentLengthCancelResponseFuture_shouldCancelUploadPart() { CompletableFuture createMultipartFuture = new CompletableFuture<>(); - MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); CompletableFuture ongoingRequest = new CompletableFuture<>(); - when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), - any(AsyncRequestBody.class))).thenReturn(ongoingRequest); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), + any(AsyncRequestBody.class))).thenReturn(ongoingRequest); CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); @@ -267,7 +281,7 @@ void upload_knownContentLengthCancelResponseFuture_shouldCancelUploadPart() { void uploadObject_createMultipartUploadFailed_shouldFail(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); - SdkClientException exception = SdkClientException.create("CompleteMultipartUpload failed"); + SdkClientException exception = SdkClientException.create("CreateMultipartUpload failed"); CompletableFuture createMultipartUploadFuture = CompletableFutureUtils.failedFuture(exception); @@ -286,8 +300,8 @@ void uploadObject_createMultipartUploadFailed_shouldFail(AsyncRequestBody asyncR void uploadObject_completeMultipartFailed_shouldFailAndAbort(AsyncRequestBody asyncRequestBody) { PutObjectRequest putObjectRequest = putObjectRequest(null); - MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); - stubSuccessfulUploadPartCalls(); + stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(s3AsyncClient); SdkClientException exception = SdkClientException.create("CompleteMultipartUpload failed"); @@ -315,8 +329,8 @@ void uploadObject_requestBodyOnError_shouldFailAndAbort(boolean contentLengthKno Long contentLength = contentLengthKnown ? MPU_CONTENT_SIZE : null; ErroneousAsyncRequestBody erroneousAsyncRequestBody = new ErroneousAsyncRequestBody(contentLength, exception); - MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); - stubSuccessfulUploadPartCalls(); + stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(s3AsyncClient); when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); @@ -327,6 +341,42 @@ void uploadObject_requestBodyOnError_shouldFailAndAbort(boolean contentLengthKno .hasRootCause(exception); } + @ParameterizedTest + @ValueSource(ints = {0, 1, 2, 3, 4}) + void uploadObject_withResumeToken_shouldInvokeListPartsAndSkipExistingParts(int numExistingParts) { + S3ResumeToken resumeToken = s3ResumeToken(numExistingParts, PART_SIZE, MPU_CONTENT_SIZE, "uploadId"); + PutObjectRequest putObjectRequest = putObjectRequestWithResumeToken(MPU_CONTENT_SIZE, resumeToken); + ListPartsRequest request = SdkPojoConversionUtils.toListPartsRequest("uploadId", putObjectRequest); + ListPartsPublisher mockPublisher = mock(ListPartsPublisher.class); + when(s3AsyncClient.listPartsPaginator(request)).thenReturn(mockPublisher); + when(mockPublisher.parts()).thenReturn(new TestPartPublisher(numExistingParts)); + + stubSuccessfulUploadPartCalls(s3AsyncClient); + stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient); + + uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)).join(); + + ArgumentCaptor listPartsRequestArgumentCaptor = ArgumentCaptor.forClass(ListPartsRequest.class); + verify(s3AsyncClient).listPartsPaginator(listPartsRequestArgumentCaptor.capture()); + assertThat(putObjectRequest.overrideConfiguration().get().executionAttributes().getAttribute(PAUSE_OBSERVABLE).pausableUpload()).isNotNull(); + + ArgumentCaptor requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); + ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); + int numTotalParts = 4; + int numPartsToSend = numTotalParts - numExistingParts; + verify(s3AsyncClient, times(numPartsToSend)).uploadPart(requestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture()); + + ArgumentCaptor completeMpuArgumentCaptor = ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class); + verify(s3AsyncClient).completeMultipartUpload(completeMpuArgumentCaptor.capture()); + + CompleteMultipartUploadRequest actualRequest = completeMpuArgumentCaptor.getValue(); + assertThat(actualRequest.multipartUpload().parts()).isEqualTo(completedParts(numTotalParts)); + } + + private List completedParts(int totalNumParts) { + return IntStream.range(1, totalNumParts + 1).mapToObj(i -> CompletedPart.builder().partNumber(i).build()).collect(Collectors.toList()); + } + private static PutObjectRequest putObjectRequest(Long contentLength) { return PutObjectRequest.builder() .bucket(BUCKET) @@ -335,37 +385,27 @@ private static PutObjectRequest putObjectRequest(Long contentLength) { .build(); } - private void stubSuccessfulUploadPartCalls() { - when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) - .thenAnswer(new Answer>() { - int numberOfCalls = 0; + private static PutObjectRequest putObjectRequestWithResumeToken(Long contentLength, S3ResumeToken resumeToken) { + return putObjectRequest(contentLength).toBuilder() + .overrideConfiguration( + o -> o.putExecutionAttribute(RESUME_TOKEN, resumeToken) + .putExecutionAttribute(PAUSE_OBSERVABLE, new PauseObservable())) + .build(); - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) { - AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); - // Draining the request body - AsyncRequestBody.subscribe(b -> {}); - - numberOfCalls++; - return CompletableFuture.completedFuture(UploadPartResponse.builder() - .checksumCRC32("crc" + numberOfCalls) - .build()); - } - }); } private OngoingStubbing> stubFailedUploadPartCalls(OngoingStubbing> stubbing, Exception exception) { return stubbing.thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) { - AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); - // Draining the request body - AsyncRequestBody.subscribe(b -> {}); + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) { + AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); + // Draining the request body + AsyncRequestBody.subscribe(b -> {}); - return CompletableFutureUtils.failedFuture(exception); - } - }); + return CompletableFutureUtils.failedFuture(exception); + } + }); } private static class UnknownContentLengthAsyncRequestBody implements AsyncRequestBody { @@ -425,4 +465,39 @@ public void cancel() { } } -} + + private static class TestPartPublisher implements SdkPublisher { + private int existingParts; + private int currentPart = 1; + + TestPartPublisher(int existingParts) { + this.existingParts = existingParts; + } + + @Override + public void subscribe(Subscriber subscriber) { + subscriber.onSubscribe(new Subscription() { + @Override + public void request(long n) { + if (n <= 0) { + subscriber.onError(new IllegalArgumentException("Demand must be positive")); + return; + } + + if (existingParts == 0) { + subscriber.onComplete(); + } + + while(existingParts > 0) { + existingParts--; + subscriber.onNext(Part.builder().partNumber(currentPart++).build()); + } + } + + @Override + public void cancel() {} + }); + } + } + +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/multipart/S3ResumeTokenTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/multipart/S3ResumeTokenTest.java new file mode 100644 index 000000000000..d31e3be7f2f2 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/multipart/S3ResumeTokenTest.java @@ -0,0 +1,59 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; + +public class S3ResumeTokenTest { + + private static final String UPLOAD_ID = "uploadId"; + private static final long PART_SIZE = 99; + private static final long TOTAL_NUM_PARTS = 20; + private static final long NUM_PARTS_COMPLETED = 2; + + @Test + public void s3ResumeToken_withValues_buildsCorrectly() { + S3ResumeToken token = S3ResumeToken.builder() + .uploadId(UPLOAD_ID) + .partSize(PART_SIZE) + .totalNumParts(TOTAL_NUM_PARTS) + .numPartsCompleted(NUM_PARTS_COMPLETED) + .build(); + + assertThat(token.uploadId()).isEqualTo(UPLOAD_ID); + assertThat(token.partSize()).isEqualTo(PART_SIZE); + assertThat(token.totalNumParts()).isEqualTo(TOTAL_NUM_PARTS); + assertThat(token.numPartsCompleted()).isEqualTo(NUM_PARTS_COMPLETED); + } + + @Test + public void s3ResumeToken_default_buildsCorrectly() { + S3ResumeToken token = S3ResumeToken.builder().build(); + + assertThat(token.uploadId()).isNull(); + assertThat(token.partSize()).isNull(); + assertThat(token.totalNumParts()).isNull(); + assertThat(token.numPartsCompleted()).isNull(); + } + + @Test + void testEqualsAndHashCodeContract() { + EqualsVerifier.forClass(S3ResumeToken.class); + } +}