diff --git a/services/docdb/pom.xml b/services/docdb/pom.xml index 045c597f4d47..bac9cdfcd2cb 100644 --- a/services/docdb/pom.xml +++ b/services/docdb/pom.xml @@ -56,11 +56,6 @@ aws-query-protocol ${awsjavasdk.version} - - software.amazon.awssdk - profiles - ${awsjavasdk.version} - software.amazon.awssdk http-auth-aws diff --git a/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java b/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java index 8db75eaf6bee..4d48de11e262 100644 --- a/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java +++ b/services/docdb/src/main/java/software/amazon/awssdk/services/docdb/internal/RdsPresignInterceptor.java @@ -15,21 +15,20 @@ package software.amazon.awssdk.services.docdb.internal; -import static software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute.AWS_CREDENTIALS; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; +import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.CredentialUtils; -import software.amazon.awssdk.auth.signer.Aws4Signer; -import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; -import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.awscore.endpoint.DefaultServiceEndpointBuilder; import software.amazon.awssdk.core.Protocol; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.exception.SdkClientException; @@ -40,7 +39,13 @@ import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.SignRequest; +import software.amazon.awssdk.http.auth.spi.signer.SignedRequest; +import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.docdb.model.DocDbRequest; @@ -79,49 +84,39 @@ public interface PresignableRequest { private final Class requestClassToPreSign; - private final Clock signingOverrideClock; + private final Clock signingClockOverride; protected RdsPresignInterceptor(Class requestClassToPreSign) { this(requestClassToPreSign, null); } - protected RdsPresignInterceptor(Class requestClassToPreSign, Clock signingOverrideClock) { + protected RdsPresignInterceptor(Class requestClassToPreSign, Clock signingClockOverride) { this.requestClassToPreSign = requestClassToPreSign; - this.signingOverrideClock = signingOverrideClock; + this.signingClockOverride = signingClockOverride; } @Override public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, - ExecutionAttributes executionAttributes) { + ExecutionAttributes executionAttributes) { SdkHttpRequest request = context.httpRequest(); - SdkRequest originalRequest = context.request(); - if (!requestClassToPreSign.isInstance(originalRequest)) { - return request; - } - - if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { - return request; + PresignableRequest presignableRequest = toPresignableRequest(request, context); + if (presignableRequest == null) { + return request.toBuilder().removeQueryParameter(PARAM_SOURCE_REGION).build(); } - PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); - + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); String sourceRegion = presignableRequest.getSourceRegion(); - if (sourceRegion == null) { - return request; - } - - String destinationRegion = executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id(); - + String destinationRegion = selectedAuthScheme.authSchemeOption().signerProperty(AwsV4HttpSigner.REGION_NAME); URI endpoint = createEndpoint(sourceRegion, SERVICE_NAME, executionAttributes); SdkHttpFullRequest.Builder marshalledRequest = presignableRequest.marshall().toBuilder().uri(endpoint); SdkHttpFullRequest requestToPresign = - marshalledRequest.method(SdkHttpMethod.GET) - .putRawQueryParameter(PARAM_DESTINATION_REGION, destinationRegion) - .removeQueryParameter(PARAM_SOURCE_REGION) - .build(); + marshalledRequest.method(SdkHttpMethod.GET) + .putRawQueryParameter(PARAM_DESTINATION_REGION, destinationRegion) + .removeQueryParameter(PARAM_SOURCE_REGION) + .build(); - requestToPresign = presignRequest(requestToPresign, executionAttributes, sourceRegion); + requestToPresign = sraPresignRequest(executionAttributes, requestToPresign, sourceRegion); String presignedUrl = requestToPresign.getUri().toString(); @@ -140,39 +135,93 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, */ protected abstract PresignableRequest adaptRequest(T originalRequest); - private SdkHttpFullRequest presignRequest(SdkHttpFullRequest request, - ExecutionAttributes attributes, - String signingRegion) { + /** + * Converts the request to a PresignableRequest if possible. + */ + private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context.ModifyHttpRequest context) { + SdkRequest originalRequest = context.request(); + if (!requestClassToPreSign.isInstance(originalRequest)) { + return null; + } + if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { + return null; + } + PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); + String sourceRegion = presignableRequest.getSourceRegion(); + if (sourceRegion == null) { + return null; + } + return presignableRequest; + } + + /** + * Presign the provided HTTP request using SRA HttpSigner + */ + private SdkHttpFullRequest sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, + String signingRegion) { + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); - Aws4Signer signer = Aws4Signer.create(); - Aws4PresignerParams presignerParams = Aws4PresignerParams.builder() - .signingRegion(Region.of(signingRegion)) - .signingName(SERVICE_NAME) - .signingClockOverride(signingOverrideClock) - .awsCredentials(resolveCredentials(attributes)) - .build(); - return signer.presign(request, presignerParams); + Instant signingInstant; + if (signingClockOverride != null) { + signingInstant = signingClockOverride.instant(); + } else { + signingInstant = Instant.now(); + } + // A fixed signing clock is used so that the current time used by the signing logic, as well as to + // determine expiration are the same. + Clock signingClock = Clock.fixed(signingInstant, ZoneOffset.UTC); + Duration expirationDuration = Duration.ofDays(7); + return doSraPresign(request, selectedAuthScheme, signingRegion, signingClock, expirationDuration); + } + + private SdkHttpFullRequest doSraPresign(SdkHttpFullRequest request, + SelectedAuthScheme selectedAuthScheme, + String signingRegion, + Clock signingClock, + Duration expirationDuration) { + CompletableFuture identityFuture = selectedAuthScheme.identity(); + T identity = CompletableFutureUtils.joinLikeSync(identityFuture); + + // Pre-signed URL puts auth info in query string, does not sign the payload, and has an expiry. + SignRequest.Builder signRequestBuilder = SignRequest + .builder(identity) + .putProperty(AwsV4FamilyHttpSigner.AUTH_LOCATION, AwsV4FamilyHttpSigner.AuthLocation.QUERY_STRING) + .putProperty(AwsV4FamilyHttpSigner.EXPIRATION_DURATION, expirationDuration) + .putProperty(HttpSigner.SIGNING_CLOCK, signingClock) + .request(request) + .payload(request.contentStreamProvider().orElse(null)); + AuthSchemeOption authSchemeOption = selectedAuthScheme.authSchemeOption(); + authSchemeOption.forEachSignerProperty(signRequestBuilder::putProperty); + // Override the region + signRequestBuilder.putProperty(AwsV4HttpSigner.REGION_NAME, signingRegion); + HttpSigner signer = selectedAuthScheme.signer(); + SignedRequest signedRequest = signer.sign(signRequestBuilder.build()); + return toSdkHttpFullRequest(signedRequest); } - private AwsCredentials resolveCredentials(ExecutionAttributes attributes) { - return attributes.getOptionalAttribute(SELECTED_AUTH_SCHEME) - .map(selectedAuthScheme -> selectedAuthScheme.identity()) - .map(identityFuture -> CompletableFutureUtils.joinLikeSync(identityFuture)) - .filter(identity -> identity instanceof AwsCredentialsIdentity) - .map(identity -> { - AwsCredentialsIdentity awsCredentialsIdentity = (AwsCredentialsIdentity) identity; - return CredentialUtils.toCredentials(awsCredentialsIdentity); - }).orElse(attributes.getAttribute(AWS_CREDENTIALS)); + private SdkHttpFullRequest toSdkHttpFullRequest(SignedRequest signedRequest) { + SdkHttpRequest request = signedRequest.request(); + + return SdkHttpFullRequest.builder() + .contentStreamProvider(signedRequest.payload().orElse(null)) + .protocol(request.protocol()) + .method(request.method()) + .host(request.host()) + .port(request.port()) + .encodedPath(request.encodedPath()) + .applyMutation(r -> request.forEachHeader(r::putHeader)) + .applyMutation(r -> request.forEachRawQueryParameter(r::putRawQueryParameter)) + .removeQueryParameter(PARAM_SOURCE_REGION) + .build(); } private URI createEndpoint(String regionName, String serviceName, ExecutionAttributes attributes) { Region region = Region.of(regionName); - if (region == null) { throw SdkClientException.builder() .message("{" + serviceName + ", " + regionName + "} was not " - + "found in region metadata. Update to latest version of SDK and try again.") + + "found in region metadata. Update to latest version of SDK and try again.") .build(); } diff --git a/services/docdb/src/main/resources/codegen-resources/customization.config b/services/docdb/src/main/resources/codegen-resources/customization.config index 2f1a69a71316..16e12a0d7085 100644 --- a/services/docdb/src/main/resources/codegen-resources/customization.config +++ b/services/docdb/src/main/resources/codegen-resources/customization.config @@ -1,4 +1,6 @@ { + "useSraAuth": true, + "enableGenerateCompiledEndpointRules": true, "verifiedSimpleMethods" : [ "describeDBClusterParameterGroups", "describeDBClusterSnapshots", diff --git a/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestHandlerTest.java b/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestHandlerTest.java index 1a261fc2ebdd..62cb53e9b5c6 100644 --- a/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestHandlerTest.java +++ b/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestHandlerTest.java @@ -15,173 +15,345 @@ package software.amazon.awssdk.services.docdb.internal; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; +import java.io.IOException; +import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; import java.time.Clock; -import java.util.Calendar; -import java.util.GregorianCalendar; -import java.util.TimeZone; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; -import software.amazon.awssdk.awscore.endpoint.DefaultServiceEndpointBuilder; -import software.amazon.awssdk.core.Protocol; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.InterceptorContext; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.profiles.ProfileFile; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.docdb.DocDbClient; +import software.amazon.awssdk.services.docdb.DocDbClientBuilder; +import software.amazon.awssdk.services.docdb.DocDbServiceClientConfiguration; +import software.amazon.awssdk.services.docdb.auth.scheme.DocDbAuthSchemeProvider; import software.amazon.awssdk.services.docdb.model.CopyDbClusterSnapshotRequest; -import software.amazon.awssdk.services.docdb.model.DocDbRequest; -import software.amazon.awssdk.services.docdb.transform.CopyDbClusterSnapshotRequestMarshaller; +import software.amazon.awssdk.utils.IoUtils; +import software.amazon.awssdk.utils.Validate; /** * Unit Tests for {@link RdsPresignInterceptor} */ -public class PresignRequestHandlerTest { - private static final AwsBasicCredentials CREDENTIALS = AwsBasicCredentials.create("foo", "bar"); - private static final Region DESTINATION_REGION = Region.of("us-west-2"); +class PresignRequestHandlerTest { + private static String TEST_KMS_KEY_ID = "arn:aws:kms:us-west-2:123456789012:key/" + + "11111111-2222-3333-4444-555555555555"; - private static final RdsPresignInterceptor presignInterceptor = new CopyDbClusterSnapshotPresignInterceptor(); - private final CopyDbClusterSnapshotRequestMarshaller marshaller = - new CopyDbClusterSnapshotRequestMarshaller(RdsPresignInterceptor.PROTOCOL_FACTORY); + @ParameterizedTest + @MethodSource("testCases") + public void testExpectations(TestCase testCase) { + // Arrange + CapturingInterceptor interceptor = new CapturingInterceptor(); + DocDbClientBuilder clientBuilder = client(interceptor, testCase.signingClockOverride); + testCase.clientConfigure.accept(clientBuilder); + DocDbClient client = clientBuilder.build(); - @Test - public void testSetsPresignedUrl() { - CopyDbClusterSnapshotRequest request = makeTestRequest(); - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + // Act + assertThatThrownBy(() -> testCase.clientConsumer.accept(client)) + .hasMessageContaining("boom!"); - assertNotNull(presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - } + // Assert + SdkHttpFullRequest request = (SdkHttpFullRequest) interceptor.httpRequest(); + Map> rawQueryParameters = rawQueryParameters(request); - @Test - public void testComputesPresignedUrlCorrectlyForCopyDbClusterSnapshotRequest() { - // Note: test data was baselined by performing actual calls, with real - // credentials to RDS and checking that they succeeded. Then the - // request was recreated with all the same parameters but with test - // credentials. - final CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder() - .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds:test-instance-ss-2016-12-20-23-19") - .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") - .sourceRegion("us-east-1") - .kmsKeyId("arn:aws:kms:us-west-2:123456789012:key/11111111-2222-3333-4444-555555555555") - .build(); - - Calendar c = new GregorianCalendar(); - c.setTimeZone(TimeZone.getTimeZone("UTC")); - // 20161221T180735Z - // Note: month is 0-based - c.set(2016, Calendar.DECEMBER, 21, 18, 7, 35); - - Clock signingDateOverride = Mockito.mock(Clock.class); - when(signingDateOverride.millis()).thenReturn(c.getTimeInMillis()); - - RdsPresignInterceptor interceptor = new CopyDbClusterSnapshotPresignInterceptor(signingDateOverride); - - SdkHttpRequest presignedRequest = modifyHttpRequest(interceptor, request, marshallRequest(request)); - - final String expectedPreSignedUrl = "https://rds.us-east-1.amazonaws.com?" + - "Action=CopyDBClusterSnapshot" + - "&Version=2014-10-31" + - "&SourceDBClusterSnapshotIdentifier=arn%3Aaws%3Ards%3Aus-east-1%3A123456789012%3Asnapshot%3Ards%3Atest-instance-ss-2016-12-20-23-19" + - "&TargetDBClusterSnapshotIdentifier=test-instance-ss-copy-2" + - "&KmsKeyId=arn%3Aaws%3Akms%3Aus-west-2%3A123456789012%3Akey%2F11111111-2222-3333-4444-555555555555" + - "&DestinationRegion=us-west-2" + - "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + - "&X-Amz-Date=20161221T180735Z" + - "&X-Amz-SignedHeaders=host" + - "&X-Amz-Expires=604800" + - "&X-Amz-Credential=foo%2F20161221%2Fus-east-1%2Frds%2Faws4_request" + - "&X-Amz-Signature=00822ebbba95e2e6ac09112aa85621fbef060a596e3e1480f9f4ac61493e9821"; - assertEquals(expectedPreSignedUrl, presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - } + // The following params should not be included in the outgoing request + assertFalse(rawQueryParameters.containsKey("SourceRegion")); + assertFalse(rawQueryParameters.containsKey("DestinationRegion")); - @Test - public void testSkipsPresigningIfUrlSet() { - CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder() - .sourceRegion("us-west-2") - .preSignedUrl("PRESIGNED") - .build(); + if (testCase.shouldContainPreSignedUrl) { + List rawPresignedUrlValue = rawQueryParameters.get("PreSignedUrl"); + assertNotNull(rawPresignedUrlValue); + assertTrue(rawPresignedUrlValue.size() == 1); + String presignedUrl = rawPresignedUrlValue.get(0); + assertNotNull(presignedUrl); + // Validate that the URL can be parsed back + URI presignedUrlAsUri = URI.create(presignedUrl); + assertNotNull(presignedUrlAsUri); + if (testCase.expectedDestinationRegion != null) { + assertTrue(presignedUrl.contains("DestinationRegion=" + testCase.expectedDestinationRegion)); + } + if (testCase.expectedUri != null) { + assertEquals(normalize(URI.create(testCase.expectedUri)), normalize(presignedUrlAsUri)); + } + } else { + assertFalse(rawQueryParameters.containsKey("PreSignedUrl")); + } + } + public static List testCases() { + return Arrays.asList( + builder("CopyDbClusterSnapshot - Sets pre-signed URL when sourceRegion is set") + .clientConsumer(c -> c.copyDBClusterSnapshot(makeTestRequestBuilder() + .sourceRegion("us-east-1") + .build())) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("CopyDbClusterSnapshot - Doesn't set pre-signed URL when sourceRegion is NOT set") + .clientConsumer(c -> c.copyDBClusterSnapshot(makeTestRequestBuilder().build())) + .shouldContainPreSignedUrl(false) + .build(), + builder("CopyDbClusterSnapshot - Does not override pre-signed URL") + .clientConsumer(c -> c.copyDBClusterSnapshot( + makeTestRequestBuilder() + .sourceRegion("us-west-2") + .preSignedUrl("http://localhost?foo=bar") + .build())) + .shouldContainPreSignedUrl(true) + .expectedUri("http://localhost?foo=bar") + .build(), + builder("CopyDbClusterSnapshot - Fixed time") + .clientConfigure(c -> c.region(Region.US_WEST_2)) + .clientConsumer(c -> c.copyDBClusterSnapshot( + makeTestRequestBuilder() + .sourceRegion("us-east-1") + .build())) + .shouldContainPreSignedUrl(true) + .signingClockOverride(Clock.fixed(Instant.parse("2016-12-21T18:07:35.000Z"), ZoneId.of("UTC"))) + .expectedUri(fixedTimePresignedUrl()) + .build(), - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + builder("createDBCluster With SourceRegion Sends Presigned Url") + .clientConsumer(c -> c.createDBCluster(r -> r.kmsKeyId(TEST_KMS_KEY_ID) + .sourceRegion("us-west-2"))) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("createDBCluster Without SourceRegion Does NOT Send PresignedUrl") + .clientConsumer(c -> c.createDBCluster(r -> r.kmsKeyId(TEST_KMS_KEY_ID))) + .shouldContainPreSignedUrl(false) + .build() + ); + } - assertEquals("PRESIGNED", presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); + private static CopyDbClusterSnapshotRequest.Builder makeTestRequestBuilder() { + return CopyDbClusterSnapshotRequest + .builder() + .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds" + + ":test-instance-ss-2016-12-20-23-19") + .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") + .kmsKeyId(TEST_KMS_KEY_ID); } - @Test - public void testSkipsPresigningIfSourceRegionNotSet() { - CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder().build(); + private static DocDbClientBuilder client(CapturingInterceptor interceptor, Clock signingClockOverride) { + DocDbClientBuilder builder = DocDbClient + .builder() + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("foo", "bar"))) + .region(Region.US_EAST_1) + .addPlugin(c -> { + // Adds the capturing interceptor. + DocDbServiceClientConfiguration.Builder config = + Validate.isInstanceOf(DocDbServiceClientConfiguration.Builder.class, c, + "\uD83E\uDD14"); + config.overrideConfiguration(oc -> oc.addExecutionInterceptor(interceptor)); + }); - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + if (signingClockOverride != null) { + // Adds a auth scheme wrapper that handles the clock override + builder.addPlugin(c -> { + DocDbServiceClientConfiguration.Builder config = + Validate.isInstanceOf(DocDbServiceClientConfiguration.Builder.class, c, "\uD83E\uDD14"); + config.authSchemeProvider(clockOverridingAuthScheme(config.authSchemeProvider(), signingClockOverride)); + }); + } + return builder; + } - assertNull(presignedRequest.rawQueryParameters().get("PreSignedUrl")); + private static DocDbAuthSchemeProvider clockOverridingAuthScheme(DocDbAuthSchemeProvider source, Clock signingClockOverride) { + return authSchemeParams -> { + List authSchemeOptions = source.resolveAuthScheme(authSchemeParams); + List result = new ArrayList<>(authSchemeOptions.size()); + for (AuthSchemeOption option : authSchemeOptions) { + if (option.schemeId().equals(AwsV4AuthScheme.SCHEME_ID)) { + option = option.toBuilder() + .putSignerProperty(AwsV4FamilyHttpSigner.SIGNING_CLOCK, signingClockOverride) + .build(); + } + result.add(option); + } + return result; + }; } - @Test - public void testParsesDestinationRegionfromRequestEndpoint() throws URISyntaxException { - CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder() - .sourceRegion("us-east-1") - .build(); - Region destination = Region.of("us-west-2"); - SdkHttpFullRequest marshalled = marshallRequest(request); + static String fixedTimePresignedUrl() { + return + "https://rds.us-east-1.amazonaws.com?" + + "Action=CopyDBClusterSnapshot" + + "&Version=2014-10-31" + + "&SourceDBClusterSnapshotIdentifier=arn%3Aaws%3Ards%3Aus-east-1%3A123456789012" + + "%3Asnapshot%3Ards%3Atest-instance-ss-2016-12-20-23-19" + + "&TargetDBClusterSnapshotIdentifier=test-instance-ss-copy-2" + + "&KmsKeyId=arn%3Aaws%3Akms%3Aus-west-2%3A123456789012%3Akey%2F11111111-2222-3333" + + "-4444-555555555555" + + "&DestinationRegion=us-west-2" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Date=20161221T180735Z" + + "&X-Amz-SignedHeaders=host" + + "&X-Amz-Credential=foo%2F20161221%2Fus-east-1%2Frds%2Faws4_request" + + "&X-Amz-Expires=604800" + + "&X-Amz-Signature=00822ebbba95e2e6ac09112aa85621fbef060a596e3e1480f9f4ac61493e9821"; + } - final SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshalled); + private Map> rawQueryParameters(SdkHttpFullRequest request) { + // Retrieve back from the query parameters from the body, this is best-effort only. + try { + String decodedQueryParams = IoUtils.toUtf8String(request.contentStreamProvider().get().newStream()); + String[] keyValuePairs = decodedQueryParams.split("&"); + Map> result = new LinkedHashMap<>(); + for (String keyValuePair : keyValuePairs) { + String[] kvpParts = keyValuePair.split("=", 2); + String value = URLDecoder.decode(kvpParts.length > 1 ? kvpParts[1] : "", StandardCharsets.UTF_8.name()); + result.computeIfAbsent(kvpParts[0], x -> new ArrayList<>()).add(value); + } + return result; + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } - final URI presignedUrl = new URI(presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - assertTrue(presignedUrl.toString().contains("DestinationRegion=" + destination.id())); + static TestCaseBuilder builder(String name) { + return new TestCaseBuilder() + .clientConfigure(c -> { + }) + .name(name); } - @Test - public void testSourceRegionRemovedFromOriginalRequest() { - CopyDbClusterSnapshotRequest request = makeTestRequest(); - SdkHttpFullRequest marshalled = marshallRequest(request); - SdkHttpRequest actual = modifyHttpRequest(presignInterceptor, request, marshalled); + private static String normalize(URI uri) { + String uriAsString = uri.toString(); + int queryStart = uriAsString.indexOf('?'); + if (queryStart == -1) { + return uriAsString; + } + String uriQueryPrefix = uriAsString.substring(0, queryStart); + String query = uri.getQuery(); + if (query == null) { + return uriAsString; + } + if (!query.isEmpty()) { + String[] queryParts = query.split("&"); + query = Arrays.stream(queryParts) + .sorted() + .collect(Collectors.joining("&")); - assertFalse(actual.rawQueryParameters().containsKey("SourceRegion")); + } + return uriQueryPrefix + "?" + query; } - private SdkHttpFullRequest marshallRequest(CopyDbClusterSnapshotRequest request) { - SdkHttpFullRequest.Builder marshalled = marshaller.marshall(request).toBuilder(); + static class TestCase { + private final String name; + private final Consumer clientConfigure; + private final Consumer clientConsumer; + private final Boolean shouldContainPreSignedUrl; + private final String expectedDestinationRegion; + private final Clock signingClockOverride; + private final String expectedUri; - URI endpoint = new DefaultServiceEndpointBuilder("rds", Protocol.HTTPS.toString()) - .withRegion(DESTINATION_REGION) - .getServiceEndpoint(); - return marshalled.uri(endpoint).build(); + TestCase(TestCaseBuilder builder) { + this.name = Validate.notNull(builder.name, "name"); + this.clientConsumer = Validate.notNull(builder.clientConsumer, "clientConsumer"); + this.clientConfigure = Validate.notNull(builder.clientConfigure, "clientConfigure"); + this.shouldContainPreSignedUrl = builder.shouldContainPreSignedUrl; + this.expectedDestinationRegion = builder.expectedDestinationRegion; + this.signingClockOverride = builder.signingClockOverride; + this.expectedUri = builder.expectedUri; + } } - private ExecutionAttributes executionAttributes() { - return new ExecutionAttributes().putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, CREDENTIALS) - .putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, DESTINATION_REGION) - .putAttribute(SdkExecutionAttribute.PROFILE_FILE_SUPPLIER, - ProfileFile::defaultProfileFile) - .putAttribute(SdkExecutionAttribute.PROFILE_NAME, "default"); - } + static class TestCaseBuilder { + private String name; + private Consumer clientConfigure; + private Consumer clientConsumer; + private Boolean shouldContainPreSignedUrl; + private String expectedDestinationRegion; + private Clock signingClockOverride; + private String expectedUri; + + private TestCaseBuilder name(String name) { + this.name = name; + return this; + } - private CopyDbClusterSnapshotRequest makeTestRequest() { - return CopyDbClusterSnapshotRequest.builder() - .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds:test-instance-ss-2016-12-20-23-19") - .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") - .sourceRegion("us-east-1") - .kmsKeyId("arn:aws:kms:us-west-2:123456789012:key/11111111-2222-3333-4444-555555555555") - .build(); + private TestCaseBuilder clientConfigure(Consumer clientConfigure) { + this.clientConfigure = clientConfigure; + return this; + } + + private TestCaseBuilder clientConsumer(Consumer clientConsumer) { + this.clientConsumer = clientConsumer; + return this; + } + + private TestCaseBuilder shouldContainPreSignedUrl(Boolean value) { + this.shouldContainPreSignedUrl = value; + return this; + } + + private TestCaseBuilder expectedDestinationRegion(String value) { + this.expectedDestinationRegion = value; + return this; + } + + public TestCaseBuilder signingClockOverride(Clock signingClockOverride) { + this.signingClockOverride = signingClockOverride; + return this; + } + + public TestCaseBuilder expectedUri(String expectedUri) { + this.expectedUri = expectedUri; + return this; + } + + public TestCase build() { + return new TestCase(this); + } } - private SdkHttpRequest modifyHttpRequest(ExecutionInterceptor interceptor, - DocDbRequest request, - SdkHttpFullRequest httpRequest) { - InterceptorContext context = InterceptorContext.builder().request(request).httpRequest(httpRequest).build(); - return interceptor.modifyHttpRequest(context, executionAttributes()); + static class CapturingInterceptor implements ExecutionInterceptor { + private static final RuntimeException BOOM = new RuntimeException("boom!"); + private Context.BeforeTransmission context; + private ExecutionAttributes executionAttributes; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.context = context; + this.executionAttributes = executionAttributes; + throw BOOM; + } + + public ExecutionAttributes executionAttributes() { + return executionAttributes; + } + + public SdkHttpRequest httpRequest() { + return context.httpRequest(); + } } } diff --git a/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestWireMockTest.java b/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestWireMockTest.java index 6342bbabc251..5b68da7dae9f 100644 --- a/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestWireMockTest.java +++ b/services/docdb/src/test/java/software/amazon/awssdk/services/docdb/internal/PresignRequestWireMockTest.java @@ -78,12 +78,6 @@ public void createDbClusterWithSourceRegionSendsPresignedUrl() { "CreateDBCluster"); } - @Test - public void createDBInstanceReadReplicaWithSourceRegionSendsPresignedUrl() { - verifyMethodCallSendsPresignedUrl(() -> client.createDBCluster(r -> r.sourceRegion("us-west-2")), - "CreateDBCluster"); - } - public void verifyMethodCallSendsPresignedUrl(Runnable methodCall, String actionName) { stubFor(any(anyUrl()).willReturn(aResponse().withStatus(200).withBody(""))); diff --git a/services/neptune/pom.xml b/services/neptune/pom.xml index a4acda0bb292..1b157de2c547 100644 --- a/services/neptune/pom.xml +++ b/services/neptune/pom.xml @@ -56,11 +56,6 @@ aws-query-protocol ${awsjavasdk.version} - - software.amazon.awssdk - profiles - ${awsjavasdk.version} - software.amazon.awssdk http-auth-aws diff --git a/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java b/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java index caba12352862..07dd4567100c 100644 --- a/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java +++ b/services/neptune/src/main/java/software/amazon/awssdk/services/neptune/internal/RdsPresignInterceptor.java @@ -15,21 +15,20 @@ package software.amazon.awssdk.services.neptune.internal; -import static software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute.AWS_CREDENTIALS; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; +import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.CredentialUtils; -import software.amazon.awssdk.auth.signer.Aws4Signer; -import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; -import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.awscore.endpoint.DefaultServiceEndpointBuilder; import software.amazon.awssdk.core.Protocol; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.exception.SdkClientException; @@ -40,7 +39,13 @@ import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.SignRequest; +import software.amazon.awssdk.http.auth.spi.signer.SignedRequest; +import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.neptune.model.NeptuneRequest; @@ -80,49 +85,39 @@ public interface PresignableRequest { private final Class requestClassToPreSign; - private final Clock signingOverrideClock; + private final Clock signingClockOverride; protected RdsPresignInterceptor(Class requestClassToPreSign) { this(requestClassToPreSign, null); } - protected RdsPresignInterceptor(Class requestClassToPreSign, Clock signingOverrideClock) { + protected RdsPresignInterceptor(Class requestClassToPreSign, Clock signingClockOverride) { this.requestClassToPreSign = requestClassToPreSign; - this.signingOverrideClock = signingOverrideClock; + this.signingClockOverride = signingClockOverride; } @Override public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, - ExecutionAttributes executionAttributes) { + ExecutionAttributes executionAttributes) { SdkHttpRequest request = context.httpRequest(); - SdkRequest originalRequest = context.request(); - if (!requestClassToPreSign.isInstance(originalRequest)) { - return request; - } - - if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { - return request; + PresignableRequest presignableRequest = toPresignableRequest(request, context); + if (presignableRequest == null) { + return request.toBuilder().removeQueryParameter(PARAM_SOURCE_REGION).build(); } - PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); - + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); String sourceRegion = presignableRequest.getSourceRegion(); - if (sourceRegion == null) { - return request; - } - - String destinationRegion = executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id(); - + String destinationRegion = selectedAuthScheme.authSchemeOption().signerProperty(AwsV4HttpSigner.REGION_NAME); URI endpoint = createEndpoint(sourceRegion, SERVICE_NAME, executionAttributes); SdkHttpFullRequest.Builder marshalledRequest = presignableRequest.marshall().toBuilder().uri(endpoint); SdkHttpFullRequest requestToPresign = - marshalledRequest.method(SdkHttpMethod.GET) - .putRawQueryParameter(PARAM_DESTINATION_REGION, destinationRegion) - .removeQueryParameter(PARAM_SOURCE_REGION) - .build(); + marshalledRequest.method(SdkHttpMethod.GET) + .putRawQueryParameter(PARAM_DESTINATION_REGION, destinationRegion) + .removeQueryParameter(PARAM_SOURCE_REGION) + .build(); - requestToPresign = presignRequest(requestToPresign, executionAttributes, sourceRegion); + requestToPresign = sraPresignRequest(executionAttributes, requestToPresign, sourceRegion); String presignedUrl = requestToPresign.getUri().toString(); @@ -141,48 +136,102 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, */ protected abstract PresignableRequest adaptRequest(T originalRequest); - private SdkHttpFullRequest presignRequest(SdkHttpFullRequest request, - ExecutionAttributes attributes, - String signingRegion) { + /** + * Converts the request to a PresignableRequest if possible. + */ + private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context.ModifyHttpRequest context) { + SdkRequest originalRequest = context.request(); + if (!requestClassToPreSign.isInstance(originalRequest)) { + return null; + } + if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { + return null; + } + PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); + String sourceRegion = presignableRequest.getSourceRegion(); + if (sourceRegion == null) { + return null; + } + return presignableRequest; + } + + /** + * Presign the provided HTTP request using SRA HttpSigner + */ + private SdkHttpFullRequest sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, + String signingRegion) { + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); - Aws4Signer signer = Aws4Signer.create(); - Aws4PresignerParams presignerParams = Aws4PresignerParams.builder() - .signingRegion(Region.of(signingRegion)) - .signingName(SERVICE_NAME) - .signingClockOverride(signingOverrideClock) - .awsCredentials(resolveCredentials(attributes)) - .build(); - return signer.presign(request, presignerParams); + Instant signingInstant; + if (signingClockOverride != null) { + signingInstant = signingClockOverride.instant(); + } else { + signingInstant = Instant.now(); + } + // A fixed signing clock is used so that the current time used by the signing logic, as well as to + // determine expiration are the same. + Clock signingClock = Clock.fixed(signingInstant, ZoneOffset.UTC); + Duration expirationDuration = Duration.ofDays(7); + return doSraPresign(request, selectedAuthScheme, signingRegion, signingClock, expirationDuration); + } + + private SdkHttpFullRequest doSraPresign(SdkHttpFullRequest request, + SelectedAuthScheme selectedAuthScheme, + String signingRegion, + Clock signingClock, + Duration expirationDuration) { + CompletableFuture identityFuture = selectedAuthScheme.identity(); + T identity = CompletableFutureUtils.joinLikeSync(identityFuture); + + // Pre-signed URL puts auth info in query string, does not sign the payload, and has an expiry. + SignRequest.Builder signRequestBuilder = SignRequest + .builder(identity) + .putProperty(AwsV4FamilyHttpSigner.AUTH_LOCATION, AwsV4FamilyHttpSigner.AuthLocation.QUERY_STRING) + .putProperty(AwsV4FamilyHttpSigner.EXPIRATION_DURATION, expirationDuration) + .putProperty(HttpSigner.SIGNING_CLOCK, signingClock) + .request(request) + .payload(request.contentStreamProvider().orElse(null)); + AuthSchemeOption authSchemeOption = selectedAuthScheme.authSchemeOption(); + authSchemeOption.forEachSignerProperty(signRequestBuilder::putProperty); + // Override the region + signRequestBuilder.putProperty(AwsV4HttpSigner.REGION_NAME, signingRegion); + HttpSigner signer = selectedAuthScheme.signer(); + SignedRequest signedRequest = signer.sign(signRequestBuilder.build()); + return toSdkHttpFullRequest(signedRequest); } - private AwsCredentials resolveCredentials(ExecutionAttributes attributes) { - return attributes.getOptionalAttribute(SELECTED_AUTH_SCHEME) - .map(selectedAuthScheme -> selectedAuthScheme.identity()) - .map(identityFuture -> CompletableFutureUtils.joinLikeSync(identityFuture)) - .filter(identity -> identity instanceof AwsCredentialsIdentity) - .map(identity -> { - AwsCredentialsIdentity awsCredentialsIdentity = (AwsCredentialsIdentity) identity; - return CredentialUtils.toCredentials(awsCredentialsIdentity); - }).orElse(attributes.getAttribute(AWS_CREDENTIALS)); + private SdkHttpFullRequest toSdkHttpFullRequest(SignedRequest signedRequest) { + SdkHttpRequest request = signedRequest.request(); + + return SdkHttpFullRequest.builder() + .contentStreamProvider(signedRequest.payload().orElse(null)) + .protocol(request.protocol()) + .method(request.method()) + .host(request.host()) + .port(request.port()) + .encodedPath(request.encodedPath()) + .applyMutation(r -> request.forEachHeader(r::putHeader)) + .applyMutation(r -> request.forEachRawQueryParameter(r::putRawQueryParameter)) + .removeQueryParameter(PARAM_SOURCE_REGION) + .build(); } private URI createEndpoint(String regionName, String serviceName, ExecutionAttributes attributes) { Region region = Region.of(regionName); - if (region == null) { throw SdkClientException.builder() .message("{" + serviceName + ", " + regionName + "} was not " - + "found in region metadata. Update to latest version of SDK and try again.") + + "found in region metadata. Update to latest version of SDK and try again.") .build(); } return new DefaultServiceEndpointBuilder(SERVICE_NAME, Protocol.HTTPS.toString()) - .withRegion(region) - .withProfileFile(attributes.getAttribute(SdkExecutionAttribute.PROFILE_FILE_SUPPLIER)) - .withProfileName(attributes.getAttribute(SdkExecutionAttribute.PROFILE_NAME)) - .withDualstackEnabled(attributes.getAttribute(AwsExecutionAttribute.DUALSTACK_ENDPOINT_ENABLED)) - .withFipsEnabled(attributes.getAttribute(AwsExecutionAttribute.FIPS_ENDPOINT_ENABLED)) - .getServiceEndpoint(); + .withRegion(region) + .withProfileFile(attributes.getAttribute(SdkExecutionAttribute.PROFILE_FILE_SUPPLIER)) + .withProfileName(attributes.getAttribute(SdkExecutionAttribute.PROFILE_NAME)) + .withDualstackEnabled(attributes.getAttribute(AwsExecutionAttribute.DUALSTACK_ENDPOINT_ENABLED)) + .withFipsEnabled(attributes.getAttribute(AwsExecutionAttribute.FIPS_ENDPOINT_ENABLED)) + .getServiceEndpoint(); } } diff --git a/services/neptune/src/test/java/software/amazon/awssdk/services/neptune/internal/PresignRequestHandlerTest.java b/services/neptune/src/test/java/software/amazon/awssdk/services/neptune/internal/PresignRequestHandlerTest.java index e0f083a4da88..e6058d6aa08b 100644 --- a/services/neptune/src/test/java/software/amazon/awssdk/services/neptune/internal/PresignRequestHandlerTest.java +++ b/services/neptune/src/test/java/software/amazon/awssdk/services/neptune/internal/PresignRequestHandlerTest.java @@ -15,173 +15,349 @@ package software.amazon.awssdk.services.neptune.internal; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; +import java.io.IOException; +import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; import java.time.Clock; -import java.util.Calendar; -import java.util.GregorianCalendar; -import java.util.TimeZone; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; -import software.amazon.awssdk.awscore.endpoint.DefaultServiceEndpointBuilder; -import software.amazon.awssdk.core.Protocol; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.InterceptorContext; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.profiles.ProfileFile; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.neptune.NeptuneClient; +import software.amazon.awssdk.services.neptune.NeptuneClientBuilder; +import software.amazon.awssdk.services.neptune.NeptuneServiceClientConfiguration; +import software.amazon.awssdk.services.neptune.auth.scheme.NeptuneAuthSchemeProvider; import software.amazon.awssdk.services.neptune.model.CopyDbClusterSnapshotRequest; -import software.amazon.awssdk.services.neptune.model.NeptuneRequest; -import software.amazon.awssdk.services.neptune.transform.CopyDbClusterSnapshotRequestMarshaller; +import software.amazon.awssdk.utils.IoUtils; +import software.amazon.awssdk.utils.Validate; + +/** + * Unit Tests for {@link software.amazon.awssdk.services.neptune.internal.RdsPresignInterceptor} + */ /** * Unit Tests for {@link RdsPresignInterceptor} */ -public class PresignRequestHandlerTest { - private static final AwsBasicCredentials CREDENTIALS = AwsBasicCredentials.create("foo", "bar"); - private static final Region DESTINATION_REGION = Region.of("us-west-2"); +class PresignRequestHandlerTest { + private static String TEST_KMS_KEY_ID = "arn:aws:kms:us-west-2:123456789012:key/" + + "11111111-2222-3333-4444-555555555555"; - private static final RdsPresignInterceptor presignInterceptor = new CopyDbClusterSnapshotPresignInterceptor(); - private final CopyDbClusterSnapshotRequestMarshaller marshaller = - new CopyDbClusterSnapshotRequestMarshaller(RdsPresignInterceptor.PROTOCOL_FACTORY); + @ParameterizedTest + @MethodSource("testCases") + public void testExpectations(TestCase testCase) { + // Arrange + CapturingInterceptor interceptor = new CapturingInterceptor(); + NeptuneClientBuilder clientBuilder = client(interceptor, testCase.signingClockOverride); + testCase.clientConfigure.accept(clientBuilder); + NeptuneClient client = clientBuilder.build(); - @Test - public void testSetsPresignedUrl() { - CopyDbClusterSnapshotRequest request = makeTestRequest(); - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + // Act + assertThatThrownBy(() -> testCase.clientConsumer.accept(client)) + .hasMessageContaining("boom!"); - assertNotNull(presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - } + // Assert + SdkHttpFullRequest request = (SdkHttpFullRequest) interceptor.httpRequest(); + Map> rawQueryParameters = rawQueryParameters(request); - @Test - public void testComputesPresignedUrlCorrectlyForCopyDbClusterSnapshotRequest() { - // Note: test data was baselined by performing actual calls, with real - // credentials to RDS and checking that they succeeded. Then the - // request was recreated with all the same parameters but with test - // credentials. - final CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder() - .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds:test-instance-ss-2016-12-20-23-19") - .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") - .sourceRegion("us-east-1") - .kmsKeyId("arn:aws:kms:us-west-2:123456789012:key/11111111-2222-3333-4444-555555555555") - .build(); - - Calendar c = new GregorianCalendar(); - c.setTimeZone(TimeZone.getTimeZone("UTC")); - // 20161221T180735Z - // Note: month is 0-based - c.set(2016, Calendar.DECEMBER, 21, 18, 7, 35); - - Clock signingDateOverride = Mockito.mock(Clock.class); - when(signingDateOverride.millis()).thenReturn(c.getTimeInMillis()); - - RdsPresignInterceptor interceptor = new CopyDbClusterSnapshotPresignInterceptor(signingDateOverride); - - SdkHttpRequest presignedRequest = modifyHttpRequest(interceptor, request, marshallRequest(request)); - - final String expectedPreSignedUrl = "https://rds.us-east-1.amazonaws.com?" + - "Action=CopyDBClusterSnapshot" + - "&Version=2014-10-31" + - "&SourceDBClusterSnapshotIdentifier=arn%3Aaws%3Ards%3Aus-east-1%3A123456789012%3Asnapshot%3Ards%3Atest-instance-ss-2016-12-20-23-19" + - "&TargetDBClusterSnapshotIdentifier=test-instance-ss-copy-2" + - "&KmsKeyId=arn%3Aaws%3Akms%3Aus-west-2%3A123456789012%3Akey%2F11111111-2222-3333-4444-555555555555" + - "&DestinationRegion=us-west-2" + - "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + - "&X-Amz-Date=20161221T180735Z" + - "&X-Amz-SignedHeaders=host" + - "&X-Amz-Expires=604800" + - "&X-Amz-Credential=foo%2F20161221%2Fus-east-1%2Frds%2Faws4_request" + - "&X-Amz-Signature=00822ebbba95e2e6ac09112aa85621fbef060a596e3e1480f9f4ac61493e9821"; - assertEquals(expectedPreSignedUrl, presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - } + // The following params should not be included in the outgoing request + assertFalse(rawQueryParameters.containsKey("SourceRegion")); + assertFalse(rawQueryParameters.containsKey("DestinationRegion")); - @Test - public void testSkipsPresigningIfUrlSet() { - CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder() - .sourceRegion("us-west-2") - .preSignedUrl("PRESIGNED") - .build(); + if (testCase.shouldContainPreSignedUrl) { + List rawPresignedUrlValue = rawQueryParameters.get("PreSignedUrl"); + assertNotNull(rawPresignedUrlValue); + assertTrue(rawPresignedUrlValue.size() == 1); + String presignedUrl = rawPresignedUrlValue.get(0); + assertNotNull(presignedUrl); + // Validate that the URL can be parsed back + URI presignedUrlAsUri = URI.create(presignedUrl); + assertNotNull(presignedUrlAsUri); + if (testCase.expectedDestinationRegion != null) { + assertTrue(presignedUrl.contains("DestinationRegion=" + testCase.expectedDestinationRegion)); + } + if (testCase.expectedUri != null) { + assertEquals(normalize(URI.create(testCase.expectedUri)), normalize(presignedUrlAsUri)); + } + } else { + assertFalse(rawQueryParameters.containsKey("PreSignedUrl")); + } + } + public static List testCases() { + return Arrays.asList( + builder("CopyDbClusterSnapshot - Sets pre-signed URL when sourceRegion is set") + .clientConsumer(c -> c.copyDBClusterSnapshot(makeTestRequestBuilder() + .sourceRegion("us-east-1") + .build())) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("CopyDbClusterSnapshot - Doesn't set pre-signed URL when sourceRegion is NOT set") + .clientConsumer(c -> c.copyDBClusterSnapshot(makeTestRequestBuilder().build())) + .shouldContainPreSignedUrl(false) + .build(), + builder("CopyDbClusterSnapshot - Does not override pre-signed URL") + .clientConsumer(c -> c.copyDBClusterSnapshot( + makeTestRequestBuilder() + .sourceRegion("us-west-2") + .preSignedUrl("http://localhost?foo=bar") + .build())) + .shouldContainPreSignedUrl(true) + .expectedUri("http://localhost?foo=bar") + .build(), + builder("CopyDbClusterSnapshot - Fixed time") + .clientConfigure(c -> c.region(Region.US_WEST_2)) + .clientConsumer(c -> c.copyDBClusterSnapshot( + makeTestRequestBuilder() + .sourceRegion("us-east-1") + .build())) + .shouldContainPreSignedUrl(true) + .signingClockOverride(Clock.fixed(Instant.parse("2016-12-21T18:07:35.000Z"), ZoneId.of("UTC"))) + .expectedUri(fixedTimePresignedUrl()) + .build(), - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + builder("createDBCluster With SourceRegion Sends Presigned Url") + .clientConsumer(c -> c.createDBCluster(r -> r.kmsKeyId(TEST_KMS_KEY_ID) + .sourceRegion("us-west-2"))) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("createDBCluster Without SourceRegion Does NOT Send PresignedUrl") + .clientConsumer(c -> c.createDBCluster(r -> r.kmsKeyId(TEST_KMS_KEY_ID))) + .shouldContainPreSignedUrl(false) + .build() + ); + } - assertEquals("PRESIGNED", presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); + private static CopyDbClusterSnapshotRequest.Builder makeTestRequestBuilder() { + return CopyDbClusterSnapshotRequest + .builder() + .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds" + + ":test-instance-ss-2016-12-20-23-19") + .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") + .kmsKeyId(TEST_KMS_KEY_ID); } - @Test - public void testSkipsPresigningIfSourceRegionNotSet() { - CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder().build(); + private static NeptuneClientBuilder client(CapturingInterceptor interceptor, Clock signingClockOverride) { + NeptuneClientBuilder builder = NeptuneClient + .builder() + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("foo", "bar"))) + .region(Region.US_EAST_1) + .addPlugin(c -> { + // Adds the capturing interceptor. + NeptuneServiceClientConfiguration.Builder config = + Validate.isInstanceOf(NeptuneServiceClientConfiguration.Builder.class, c, + "\uD83E\uDD14"); + config.overrideConfiguration(oc -> oc.addExecutionInterceptor(interceptor)); + }); - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + if (signingClockOverride != null) { + // Adds a auth scheme wrapper that handles the clock override + builder.addPlugin(c -> { + NeptuneServiceClientConfiguration.Builder config = + Validate.isInstanceOf(NeptuneServiceClientConfiguration.Builder.class, c, "\uD83E\uDD14"); + config.authSchemeProvider(clockOverridingAuthScheme(config.authSchemeProvider(), signingClockOverride)); + }); + } + return builder; + } - assertNull(presignedRequest.rawQueryParameters().get("PreSignedUrl")); + private static NeptuneAuthSchemeProvider clockOverridingAuthScheme(NeptuneAuthSchemeProvider source, + Clock signingClockOverride) { + return authSchemeParams -> { + List authSchemeOptions = source.resolveAuthScheme(authSchemeParams); + List result = new ArrayList<>(authSchemeOptions.size()); + for (AuthSchemeOption option : authSchemeOptions) { + if (option.schemeId().equals(AwsV4AuthScheme.SCHEME_ID)) { + option = option.toBuilder() + .putSignerProperty(AwsV4FamilyHttpSigner.SIGNING_CLOCK, signingClockOverride) + .build(); + } + result.add(option); + } + return result; + }; } - @Test - public void testParsesDestinationRegionfromRequestEndpoint() throws URISyntaxException { - CopyDbClusterSnapshotRequest request = CopyDbClusterSnapshotRequest.builder() - .sourceRegion("us-east-1") - .build(); - Region destination = Region.of("us-west-2"); - SdkHttpFullRequest marshalled = marshallRequest(request); + static String fixedTimePresignedUrl() { + return + "https://rds.us-east-1.amazonaws.com?" + + "Action=CopyDBClusterSnapshot" + + "&Version=2014-10-31" + + "&SourceDBClusterSnapshotIdentifier=arn%3Aaws%3Ards%3Aus-east-1%3A123456789012" + + "%3Asnapshot%3Ards%3Atest-instance-ss-2016-12-20-23-19" + + "&TargetDBClusterSnapshotIdentifier=test-instance-ss-copy-2" + + "&KmsKeyId=arn%3Aaws%3Akms%3Aus-west-2%3A123456789012%3Akey%2F11111111-2222-3333" + + "-4444-555555555555" + + "&DestinationRegion=us-west-2" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Date=20161221T180735Z" + + "&X-Amz-SignedHeaders=host" + + "&X-Amz-Credential=foo%2F20161221%2Fus-east-1%2Frds%2Faws4_request" + + "&X-Amz-Expires=604800" + + "&X-Amz-Signature=00822ebbba95e2e6ac09112aa85621fbef060a596e3e1480f9f4ac61493e9821"; + } - final SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshalled); + private Map> rawQueryParameters(SdkHttpFullRequest request) { + // Retrieve back from the query parameters from the body, this is best-effort only. + try { + String decodedQueryParams = IoUtils.toUtf8String(request.contentStreamProvider().get().newStream()); + String[] keyValuePairs = decodedQueryParams.split("&"); + Map> result = new LinkedHashMap<>(); + for (String keyValuePair : keyValuePairs) { + String[] kvpParts = keyValuePair.split("=", 2); + String value = URLDecoder.decode(kvpParts.length > 1 ? kvpParts[1] : "", StandardCharsets.UTF_8.name()); + result.computeIfAbsent(kvpParts[0], x -> new ArrayList<>()).add(value); + } + return result; + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } - final URI presignedUrl = new URI(presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - assertTrue(presignedUrl.toString().contains("DestinationRegion=" + destination.id())); + static TestCaseBuilder builder(String name) { + return new TestCaseBuilder() + .clientConfigure(c -> { + }) + .name(name); } - @Test - public void testSourceRegionRemovedFromOriginalRequest() { - CopyDbClusterSnapshotRequest request = makeTestRequest(); - SdkHttpFullRequest marshalled = marshallRequest(request); - SdkHttpRequest actual = modifyHttpRequest(presignInterceptor, request, marshalled); + private static String normalize(URI uri) { + String uriAsString = uri.toString(); + int queryStart = uriAsString.indexOf('?'); + if (queryStart == -1) { + return uriAsString; + } + String uriQueryPrefix = uriAsString.substring(0, queryStart); + String query = uri.getQuery(); + if (query == null) { + return uriAsString; + } + if (!query.isEmpty()) { + String[] queryParts = query.split("&"); + query = Arrays.stream(queryParts) + .sorted() + .collect(Collectors.joining("&")); - assertFalse(actual.rawQueryParameters().containsKey("SourceRegion")); + } + return uriQueryPrefix + "?" + query; } - private SdkHttpFullRequest marshallRequest(CopyDbClusterSnapshotRequest request) { - SdkHttpFullRequest.Builder marshalled = marshaller.marshall(request).toBuilder(); + static class TestCase { + private final String name; + private final Consumer clientConfigure; + private final Consumer clientConsumer; + private final Boolean shouldContainPreSignedUrl; + private final String expectedDestinationRegion; + private final Clock signingClockOverride; + private final String expectedUri; - URI endpoint = new DefaultServiceEndpointBuilder("rds", Protocol.HTTPS.toString()) - .withRegion(DESTINATION_REGION) - .getServiceEndpoint(); - return marshalled.uri(endpoint).build(); + TestCase(TestCaseBuilder builder) { + this.name = Validate.notNull(builder.name, "name"); + this.clientConsumer = Validate.notNull(builder.clientConsumer, "clientConsumer"); + this.clientConfigure = Validate.notNull(builder.clientConfigure, "clientConfigure"); + this.shouldContainPreSignedUrl = builder.shouldContainPreSignedUrl; + this.expectedDestinationRegion = builder.expectedDestinationRegion; + this.signingClockOverride = builder.signingClockOverride; + this.expectedUri = builder.expectedUri; + } } - private ExecutionAttributes executionAttributes() { - return new ExecutionAttributes().putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, CREDENTIALS) - .putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, DESTINATION_REGION) - .putAttribute(SdkExecutionAttribute.PROFILE_FILE_SUPPLIER, - ProfileFile::defaultProfileFile) - .putAttribute(SdkExecutionAttribute.PROFILE_NAME, "default"); - } + static class TestCaseBuilder { + private String name; + private Consumer clientConfigure; + private Consumer clientConsumer; + private Boolean shouldContainPreSignedUrl; + private String expectedDestinationRegion; + private Clock signingClockOverride; + private String expectedUri; - private CopyDbClusterSnapshotRequest makeTestRequest() { - return CopyDbClusterSnapshotRequest.builder() - .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds:test-instance-ss-2016-12-20-23-19") - .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") - .sourceRegion("us-east-1") - .kmsKeyId("arn:aws:kms:us-west-2:123456789012:key/11111111-2222-3333-4444-555555555555") - .build(); + private TestCaseBuilder name(String name) { + this.name = name; + return this; + } + + private TestCaseBuilder clientConfigure(Consumer clientConfigure) { + this.clientConfigure = clientConfigure; + return this; + } + + private TestCaseBuilder clientConsumer(Consumer clientConsumer) { + this.clientConsumer = clientConsumer; + return this; + } + + private TestCaseBuilder shouldContainPreSignedUrl(Boolean value) { + this.shouldContainPreSignedUrl = value; + return this; + } + + private TestCaseBuilder expectedDestinationRegion(String value) { + this.expectedDestinationRegion = value; + return this; + } + + public TestCaseBuilder signingClockOverride(Clock signingClockOverride) { + this.signingClockOverride = signingClockOverride; + return this; + } + + public TestCaseBuilder expectedUri(String expectedUri) { + this.expectedUri = expectedUri; + return this; + } + + public TestCase build() { + return new TestCase(this); + } } - private SdkHttpRequest modifyHttpRequest(ExecutionInterceptor interceptor, - NeptuneRequest request, - SdkHttpFullRequest httpRequest) { - InterceptorContext context = InterceptorContext.builder().request(request).httpRequest(httpRequest).build(); - return interceptor.modifyHttpRequest(context, executionAttributes()); + static class CapturingInterceptor implements ExecutionInterceptor { + private static final RuntimeException BOOM = new RuntimeException("boom!"); + private Context.BeforeTransmission context; + private ExecutionAttributes executionAttributes; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.context = context; + this.executionAttributes = executionAttributes; + throw BOOM; + } + + public ExecutionAttributes executionAttributes() { + return executionAttributes; + } + + public SdkHttpRequest httpRequest() { + return context.httpRequest(); + } } } diff --git a/services/rds/pom.xml b/services/rds/pom.xml index 7aacb29324b6..ee5effef0b24 100644 --- a/services/rds/pom.xml +++ b/services/rds/pom.xml @@ -56,11 +56,6 @@ protocol-core ${awsjavasdk.version} - - software.amazon.awssdk - profiles - ${awsjavasdk.version} - software.amazon.awssdk http-auth-aws diff --git a/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java b/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java index 46134f91dfbb..eae144050d6d 100644 --- a/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java +++ b/services/rds/src/main/java/software/amazon/awssdk/services/rds/internal/RdsPresignInterceptor.java @@ -15,21 +15,20 @@ package software.amazon.awssdk.services.rds.internal; -import static software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute.AWS_CREDENTIALS; import static software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME; import java.net.URI; import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; +import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.auth.credentials.AwsCredentials; -import software.amazon.awssdk.auth.credentials.CredentialUtils; -import software.amazon.awssdk.auth.signer.Aws4Signer; -import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; -import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.awscore.endpoint.DefaultServiceEndpointBuilder; import software.amazon.awssdk.core.Protocol; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.exception.SdkClientException; @@ -40,7 +39,13 @@ import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.http.auth.spi.signer.SignRequest; +import software.amazon.awssdk.http.auth.spi.signer.SignedRequest; +import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.rds.model.RdsRequest; @@ -79,49 +84,39 @@ public interface PresignableRequest { private final Class requestClassToPreSign; - private final Clock signingOverrideClock; + private final Clock signingClockOverride; - public RdsPresignInterceptor(Class requestClassToPreSign) { + protected RdsPresignInterceptor(Class requestClassToPreSign) { this(requestClassToPreSign, null); } - public RdsPresignInterceptor(Class requestClassToPreSign, Clock signingOverrideClock) { + protected RdsPresignInterceptor(Class requestClassToPreSign, Clock signingClockOverride) { this.requestClassToPreSign = requestClassToPreSign; - this.signingOverrideClock = signingOverrideClock; + this.signingClockOverride = signingClockOverride; } @Override public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, - ExecutionAttributes executionAttributes) { + ExecutionAttributes executionAttributes) { SdkHttpRequest request = context.httpRequest(); - SdkRequest originalRequest = context.request(); - if (!requestClassToPreSign.isInstance(originalRequest)) { - return request; - } - - if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { - return request; + PresignableRequest presignableRequest = toPresignableRequest(request, context); + if (presignableRequest == null) { + return request.toBuilder().removeQueryParameter(PARAM_SOURCE_REGION).build(); } - PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); - + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); String sourceRegion = presignableRequest.getSourceRegion(); - if (sourceRegion == null) { - return request; - } - - String destinationRegion = executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id(); - + String destinationRegion = selectedAuthScheme.authSchemeOption().signerProperty(AwsV4HttpSigner.REGION_NAME); URI endpoint = createEndpoint(sourceRegion, SERVICE_NAME, executionAttributes); SdkHttpFullRequest.Builder marshalledRequest = presignableRequest.marshall().toBuilder().uri(endpoint); SdkHttpFullRequest requestToPresign = - marshalledRequest.method(SdkHttpMethod.GET) - .putRawQueryParameter(PARAM_DESTINATION_REGION, destinationRegion) - .removeQueryParameter(PARAM_SOURCE_REGION) - .build(); + marshalledRequest.method(SdkHttpMethod.GET) + .putRawQueryParameter(PARAM_DESTINATION_REGION, destinationRegion) + .removeQueryParameter(PARAM_SOURCE_REGION) + .build(); - requestToPresign = presignRequest(requestToPresign, executionAttributes, sourceRegion); + requestToPresign = sraPresignRequest(executionAttributes, requestToPresign, sourceRegion); String presignedUrl = requestToPresign.getUri().toString(); @@ -140,39 +135,92 @@ public final SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, */ protected abstract PresignableRequest adaptRequest(T originalRequest); - private SdkHttpFullRequest presignRequest(SdkHttpFullRequest request, - ExecutionAttributes attributes, - String signingRegion) { + /** + * Converts the request to a PresignableRequest if possible. + */ + private PresignableRequest toPresignableRequest(SdkHttpRequest request, Context.ModifyHttpRequest context) { + SdkRequest originalRequest = context.request(); + if (!requestClassToPreSign.isInstance(originalRequest)) { + return null; + } + if (request.firstMatchingRawQueryParameter(PARAM_PRESIGNED_URL).isPresent()) { + return null; + } + + PresignableRequest presignableRequest = adaptRequest(requestClassToPreSign.cast(originalRequest)); + String sourceRegion = presignableRequest.getSourceRegion(); + if (sourceRegion == null) { + return null; + } + return presignableRequest; + } - Aws4Signer signer = Aws4Signer.create(); - Aws4PresignerParams presignerParams = Aws4PresignerParams.builder() - .signingRegion(Region.of(signingRegion)) - .signingName(SERVICE_NAME) - .signingClockOverride(signingOverrideClock) - .awsCredentials(resolveCredentials(attributes)) - .build(); + /** + * Presign the provided HTTP request using SRA HttpSigner + */ + private SdkHttpFullRequest sraPresignRequest(ExecutionAttributes executionAttributes, SdkHttpFullRequest request, + String signingRegion) { + SelectedAuthScheme selectedAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + Instant signingInstant; + if (signingClockOverride != null) { + signingInstant = signingClockOverride.instant(); + } else { + signingInstant = Instant.now(); + } + // A fixed signing clock is used so that the current time used by the signing logic, as well as to + // determine expiration are the same. + Clock signingClock = Clock.fixed(signingInstant, ZoneOffset.UTC); + Duration expirationDuration = Duration.ofDays(7); + return doSraPresign(request, selectedAuthScheme, signingRegion, signingClock, expirationDuration); + } - return signer.presign(request, presignerParams); + private SdkHttpFullRequest doSraPresign(SdkHttpFullRequest request, + SelectedAuthScheme selectedAuthScheme, + String signingRegion, + Clock signingClock, + Duration expirationDuration) { + CompletableFuture identityFuture = selectedAuthScheme.identity(); + T identity = CompletableFutureUtils.joinLikeSync(identityFuture); + + // Pre-signed URL puts auth info in query string, does not sign the payload, and has an expiry. + SignRequest.Builder signRequestBuilder = SignRequest + .builder(identity) + .putProperty(AwsV4FamilyHttpSigner.AUTH_LOCATION, AwsV4FamilyHttpSigner.AuthLocation.QUERY_STRING) + .putProperty(AwsV4FamilyHttpSigner.EXPIRATION_DURATION, expirationDuration) + .putProperty(HttpSigner.SIGNING_CLOCK, signingClock) + .request(request) + .payload(request.contentStreamProvider().orElse(null)); + AuthSchemeOption authSchemeOption = selectedAuthScheme.authSchemeOption(); + authSchemeOption.forEachSignerProperty(signRequestBuilder::putProperty); + // Override the region + signRequestBuilder.putProperty(AwsV4HttpSigner.REGION_NAME, signingRegion); + HttpSigner signer = selectedAuthScheme.signer(); + SignedRequest signedRequest = signer.sign(signRequestBuilder.build()); + return toSdkHttpFullRequest(signedRequest); } - private AwsCredentials resolveCredentials(ExecutionAttributes attributes) { - return attributes.getOptionalAttribute(SELECTED_AUTH_SCHEME) - .map(selectedAuthScheme -> selectedAuthScheme.identity()) - .map(identityFuture -> CompletableFutureUtils.joinLikeSync(identityFuture)) - .filter(identity -> identity instanceof AwsCredentialsIdentity) - .map(identity -> { - AwsCredentialsIdentity awsCredentialsIdentity = (AwsCredentialsIdentity) identity; - return CredentialUtils.toCredentials(awsCredentialsIdentity); - }).orElse(attributes.getAttribute(AWS_CREDENTIALS)); + private SdkHttpFullRequest toSdkHttpFullRequest(SignedRequest signedRequest) { + SdkHttpRequest request = signedRequest.request(); + + return SdkHttpFullRequest.builder() + .contentStreamProvider(signedRequest.payload().orElse(null)) + .protocol(request.protocol()) + .method(request.method()) + .host(request.host()) + .port(request.port()) + .encodedPath(request.encodedPath()) + .applyMutation(r -> request.forEachHeader(r::putHeader)) + .applyMutation(r -> request.forEachRawQueryParameter(r::putRawQueryParameter)) + .removeQueryParameter(PARAM_SOURCE_REGION) + .build(); } private URI createEndpoint(String regionName, String serviceName, ExecutionAttributes attributes) { Region region = Region.of(regionName); - if (region == null) { throw SdkClientException.builder() .message("{" + serviceName + ", " + regionName + "} was not " - + "found in region metadata. Update to latest version of SDK and try again.") + + "found in region metadata. Update to latest version of SDK and try again.") .build(); } diff --git a/services/rds/src/main/resources/codegen-resources/customization.config b/services/rds/src/main/resources/codegen-resources/customization.config index fa6c19c93539..ae3bb5b30401 100644 --- a/services/rds/src/main/resources/codegen-resources/customization.config +++ b/services/rds/src/main/resources/codegen-resources/customization.config @@ -1,4 +1,6 @@ { + "useSraAuth": true, + "enableGenerateCompiledEndpointRules": true, "shapeModifiers" : { "CopyDBSnapshotMessage" : { "inject" : [ diff --git a/services/rds/src/test/java/software/amazon/awssdk/services/rds/internal/PresignRequestHandlerTest.java b/services/rds/src/test/java/software/amazon/awssdk/services/rds/internal/PresignRequestHandlerTest.java index 65a04d3a2f2a..fff1e649fc40 100644 --- a/services/rds/src/test/java/software/amazon/awssdk/services/rds/internal/PresignRequestHandlerTest.java +++ b/services/rds/src/test/java/software/amazon/awssdk/services/rds/internal/PresignRequestHandlerTest.java @@ -15,174 +15,366 @@ package software.amazon.awssdk.services.rds.internal; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; +import java.io.IOException; +import java.io.UncheckedIOException; import java.net.URI; -import java.net.URISyntaxException; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; import java.time.Clock; -import java.util.Calendar; -import java.util.GregorianCalendar; -import java.util.TimeZone; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; -import software.amazon.awssdk.awscore.endpoint.DefaultServiceEndpointBuilder; -import software.amazon.awssdk.core.Protocol; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; -import software.amazon.awssdk.core.interceptor.InterceptorContext; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.profiles.ProfileFile; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.rds.model.CopyDbSnapshotRequest; -import software.amazon.awssdk.services.rds.model.RdsRequest; -import software.amazon.awssdk.services.rds.transform.CopyDbSnapshotRequestMarshaller; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.RdsClientBuilder; +import software.amazon.awssdk.services.rds.RdsServiceClientConfiguration; +import software.amazon.awssdk.services.rds.auth.scheme.RdsAuthSchemeProvider; +import software.amazon.awssdk.services.rds.model.CopyDbClusterSnapshotRequest; +import software.amazon.awssdk.utils.IoUtils; +import software.amazon.awssdk.utils.Validate; /** * Unit Tests for {@link RdsPresignInterceptor} */ -public class PresignRequestHandlerTest { - private static final AwsBasicCredentials CREDENTIALS = AwsBasicCredentials.create("foo", "bar"); - private static final Region DESTINATION_REGION = Region.of("us-west-2"); +class PresignRequestHandlerTest { + private static String TEST_KMS_KEY_ID = "arn:aws:kms:us-west-2:123456789012:key/" + + "11111111-2222-3333-4444-555555555555"; - private static RdsPresignInterceptor presignInterceptor = new CopyDbSnapshotPresignInterceptor(); - private final CopyDbSnapshotRequestMarshaller marshaller = - new CopyDbSnapshotRequestMarshaller(RdsPresignInterceptor.PROTOCOL_FACTORY); + @ParameterizedTest + @MethodSource("testCases") + public void testExpectations(TestCase testCase) { + // Arrange + CapturingInterceptor interceptor = new CapturingInterceptor(); + RdsClientBuilder clientBuilder = client(interceptor, testCase.signingClockOverride); + testCase.clientConfigure.accept(clientBuilder); + RdsClient client = clientBuilder.build(); - @Test - public void testSetsPresignedUrl() { - CopyDbSnapshotRequest request = makeTestRequest(); - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + // Act + assertThatThrownBy(() -> testCase.clientConsumer.accept(client)) + .hasMessageContaining("boom!"); - assertNotNull(presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - } + // Assert + SdkHttpFullRequest request = (SdkHttpFullRequest) interceptor.httpRequest(); + Map> rawQueryParameters = rawQueryParameters(request); + + // The following params should not be included in the outgoing request + assertFalse(rawQueryParameters.containsKey("SourceRegion")); + assertFalse(rawQueryParameters.containsKey("DestinationRegion")); - @Test - public void testComputesPresignedUrlCorrectly() { - // Note: test data was baselined by performing actual calls, with real - // credentials to RDS and checking that they succeeded. Then the - // request was recreated with all the same parameters but with test - // credentials. - final CopyDbSnapshotRequest request = CopyDbSnapshotRequest.builder() - .sourceDBSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds:test-instance-ss-2016-12-20-23-19") - .targetDBSnapshotIdentifier("test-instance-ss-copy-2") - .sourceRegion("us-east-1") - .kmsKeyId("arn:aws:kms:us-west-2:123456789012:key/11111111-2222-3333-4444-555555555555") - .build(); - - Calendar c = new GregorianCalendar(); - c.setTimeZone(TimeZone.getTimeZone("UTC")); - // 20161221T180735Z - // Note: month is 0-based - c.set(2016, Calendar.DECEMBER, 21, 18, 7, 35); - - Clock signingDateOverride = Mockito.mock(Clock.class); - when(signingDateOverride.millis()).thenReturn(c.getTimeInMillis()); - - RdsPresignInterceptor interceptor = new CopyDbSnapshotPresignInterceptor(signingDateOverride); - - SdkHttpRequest presignedRequest = modifyHttpRequest(interceptor, request, marshallRequest(request)); - - final String expectedPreSignedUrl = "https://rds.us-east-1.amazonaws.com?" + - "Action=CopyDBSnapshot" + - "&Version=2014-10-31" + - "&SourceDBSnapshotIdentifier=arn%3Aaws%3Ards%3Aus-east-1%3A123456789012%3Asnapshot%3Ards%3Atest-instance-ss-2016-12-20-23-19" + - "&TargetDBSnapshotIdentifier=test-instance-ss-copy-2" + - "&KmsKeyId=arn%3Aaws%3Akms%3Aus-west-2%3A123456789012%3Akey%2F11111111-2222-3333-4444-555555555555" + - "&DestinationRegion=us-west-2" + - "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + - "&X-Amz-Date=20161221T180735Z" + - "&X-Amz-SignedHeaders=host" + - "&X-Amz-Expires=604800" + - "&X-Amz-Credential=foo%2F20161221%2Fus-east-1%2Frds%2Faws4_request" + - "&X-Amz-Signature=f839ca3c728dc96e7c978befeac648296b9f778f6724073de4217173859d13d9"; - - assertEquals(expectedPreSignedUrl, presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); + if (testCase.shouldContainPreSignedUrl) { + List rawPresignedUrlValue = rawQueryParameters.get("PreSignedUrl"); + assertNotNull(rawPresignedUrlValue); + assertTrue(rawPresignedUrlValue.size() == 1); + String presignedUrl = rawPresignedUrlValue.get(0); + assertNotNull(presignedUrl); + // Validate that the URL can be parsed back + URI presignedUrlAsUri = URI.create(presignedUrl); + assertNotNull(presignedUrlAsUri); + if (testCase.expectedDestinationRegion != null) { + assertTrue(presignedUrl.contains("DestinationRegion=" + testCase.expectedDestinationRegion)); + } + if (testCase.expectedUri != null) { + assertEquals(normalize(URI.create(testCase.expectedUri)), normalize(presignedUrlAsUri)); + } + } else { + assertFalse(rawQueryParameters.containsKey("PreSignedUrl")); + } } - @Test - public void testSkipsPresigningIfUrlSet() { - CopyDbSnapshotRequest request = CopyDbSnapshotRequest.builder() - .sourceRegion("us-west-2") - .preSignedUrl("PRESIGNED") - .build(); + public static List testCases() { + return Arrays.asList( + builder("CopyDbClusterSnapshot - Sets pre-signed URL when sourceRegion is set") + .clientConsumer(c -> c.copyDBClusterSnapshot(makeTestRequestBuilder() + .sourceRegion("us-east-1") + .build())) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("CopyDbClusterSnapshot - Doesn't set pre-signed URL when sourceRegion is NOT set") + .clientConsumer(c -> c.copyDBClusterSnapshot(makeTestRequestBuilder().build())) + .shouldContainPreSignedUrl(false) + .build(), + builder("CopyDbClusterSnapshot - Does not override pre-signed URL") + .clientConsumer(c -> c.copyDBClusterSnapshot( + makeTestRequestBuilder() + .sourceRegion("us-west-2") + .preSignedUrl("http://localhost?foo=bar") + .build())) + .shouldContainPreSignedUrl(true) + .expectedUri("http://localhost?foo=bar") + .build(), + builder("CopyDbClusterSnapshot - Fixed time") + .clientConfigure(c -> c.region(Region.US_WEST_2)) + .clientConsumer(c -> c.copyDBClusterSnapshot( + makeTestRequestBuilder() + .sourceRegion("us-east-1") + .build())) + .shouldContainPreSignedUrl(true) + .signingClockOverride(Clock.fixed(Instant.parse("2016-12-21T18:07:35.000Z"), ZoneId.of("UTC"))) + .expectedUri(fixedTimePresignedUrl()) + .build(), + builder("CreateDBCluster With SourceRegion Sends Presigned Url") + .clientConsumer(c -> c.createDBCluster(r -> r.kmsKeyId(TEST_KMS_KEY_ID) + .sourceRegion("us-west-2"))) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("CreateDBCluster Without SourceRegion Does NOT Send PresignedUrl") + .clientConsumer(c -> c.createDBCluster(r -> r.kmsKeyId(TEST_KMS_KEY_ID))) + .shouldContainPreSignedUrl(false) + .build(), - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + builder("CreateDBInstanceReadReplica With SourceRegion Sends Presigned Url") + .clientConsumer(c -> c.createDBInstanceReadReplica(r -> r.kmsKeyId(TEST_KMS_KEY_ID) + .sourceRegion("us-west-2"))) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("CreateDBInstanceReadReplica Without SourceRegion Does NOT Send PresignedUrl") + .clientConsumer(c -> c.createDBInstanceReadReplica(r -> r.kmsKeyId(TEST_KMS_KEY_ID))) + .shouldContainPreSignedUrl(false) + .build(), - assertEquals("PRESIGNED", presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); + builder("StartDBInstanceAutomatedBackupsReplication With SourceRegion Sends Presigned Url") + .clientConsumer(c -> c.startDBInstanceAutomatedBackupsReplication(r -> r.kmsKeyId(TEST_KMS_KEY_ID) + .sourceRegion("us-west-2"))) + .shouldContainPreSignedUrl(true) + .expectedDestinationRegion("us-east-1") + .build(), + builder("StartDBInstanceAutomatedBackupsReplication Without SourceRegion Does NOT Send PresignedUrl") + .clientConsumer(c -> c.startDBInstanceAutomatedBackupsReplication(r -> r.kmsKeyId(TEST_KMS_KEY_ID))) + .shouldContainPreSignedUrl(false) + .build() + ); } - @Test - public void testSkipsPresigningIfSourceRegionNotSet() { - CopyDbSnapshotRequest request = CopyDbSnapshotRequest.builder().build(); + private static CopyDbClusterSnapshotRequest.Builder makeTestRequestBuilder() { + return CopyDbClusterSnapshotRequest + .builder() + .sourceDBClusterSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds" + + ":test-instance-ss-2016-12-20-23-19") + .targetDBClusterSnapshotIdentifier("test-instance-ss-copy-2") + .kmsKeyId(TEST_KMS_KEY_ID); + } - SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshallRequest(request)); + private static RdsClientBuilder client(CapturingInterceptor interceptor, Clock signingClockOverride) { + RdsClientBuilder builder = RdsClient + .builder() + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("foo", "bar"))) + .region(Region.US_EAST_1) + .addPlugin(c -> { + // Adds the capturing interceptor. + RdsServiceClientConfiguration.Builder config = + Validate.isInstanceOf(RdsServiceClientConfiguration.Builder.class, c, + "\uD83E\uDD14"); + config.overrideConfiguration(oc -> oc.addExecutionInterceptor(interceptor)); + }); - assertNull(presignedRequest.rawQueryParameters().get("PreSignedUrl")); + if (signingClockOverride != null) { + // Adds a auth scheme wrapper that handles the clock override + builder.addPlugin(c -> { + RdsServiceClientConfiguration.Builder config = + Validate.isInstanceOf(RdsServiceClientConfiguration.Builder.class, c, "\uD83E\uDD14"); + config.authSchemeProvider(clockOverridingAuthScheme(config.authSchemeProvider(), signingClockOverride)); + }); + } + return builder; } - @Test - public void testParsesDestinationRegionfromRequestEndpoint() throws URISyntaxException { - CopyDbSnapshotRequest request = CopyDbSnapshotRequest.builder() - .sourceRegion("us-east-1") - .build(); - Region destination = Region.of("us-west-2"); - SdkHttpFullRequest marshalled = marshallRequest(request); - - final SdkHttpRequest presignedRequest = modifyHttpRequest(presignInterceptor, request, marshalled); + private static RdsAuthSchemeProvider clockOverridingAuthScheme(RdsAuthSchemeProvider source, Clock signingClockOverride) { + return authSchemeParams -> { + List authSchemeOptions = source.resolveAuthScheme(authSchemeParams); + List result = new ArrayList<>(authSchemeOptions.size()); + for (AuthSchemeOption option : authSchemeOptions) { + if (option.schemeId().equals(AwsV4AuthScheme.SCHEME_ID)) { + option = option.toBuilder() + .putSignerProperty(AwsV4FamilyHttpSigner.SIGNING_CLOCK, signingClockOverride) + .build(); + } + result.add(option); + } + return result; + }; + } - final URI presignedUrl = new URI(presignedRequest.rawQueryParameters().get("PreSignedUrl").get(0)); - assertTrue(presignedUrl.toString().contains("DestinationRegion=" + destination.id())); + static String fixedTimePresignedUrl() { + return + "https://rds.us-east-1.amazonaws.com?" + + "Action=CopyDBClusterSnapshot" + + "&Version=2014-10-31" + + "&SourceDBClusterSnapshotIdentifier=arn%3Aaws%3Ards%3Aus-east-1%3A123456789012" + + "%3Asnapshot%3Ards%3Atest-instance-ss-2016-12-20-23-19" + + "&TargetDBClusterSnapshotIdentifier=test-instance-ss-copy-2" + + "&KmsKeyId=arn%3Aaws%3Akms%3Aus-west-2%3A123456789012%3Akey%2F11111111-2222-3333" + + "-4444-555555555555" + + "&DestinationRegion=us-west-2" + + "&X-Amz-Algorithm=AWS4-HMAC-SHA256" + + "&X-Amz-Date=20161221T180735Z" + + "&X-Amz-SignedHeaders=host" + + "&X-Amz-Credential=foo%2F20161221%2Fus-east-1%2Frds%2Faws4_request" + + "&X-Amz-Expires=604800" + + "&X-Amz-Signature=00822ebbba95e2e6ac09112aa85621fbef060a596e3e1480f9f4ac61493e9821"; } - @Test - public void testSourceRegionRemovedFromOriginalRequest() { - CopyDbSnapshotRequest request = makeTestRequest(); - SdkHttpFullRequest marshalled = marshallRequest(request); - SdkHttpRequest actual = modifyHttpRequest(presignInterceptor, request, marshalled); + private Map> rawQueryParameters(SdkHttpFullRequest request) { + // Retrieve back from the query parameters from the body, this is best-effort only. + try { + String decodedQueryParams = IoUtils.toUtf8String(request.contentStreamProvider().get().newStream()); + String[] keyValuePairs = decodedQueryParams.split("&"); + Map> result = new LinkedHashMap<>(); + for (String keyValuePair : keyValuePairs) { + String[] kvpParts = keyValuePair.split("=", 2); + String value = URLDecoder.decode(kvpParts.length > 1 ? kvpParts[1] : "", StandardCharsets.UTF_8.name()); + result.computeIfAbsent(kvpParts[0], x -> new ArrayList<>()).add(value); + } + return result; + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } - assertFalse(actual.rawQueryParameters().containsKey("SourceRegion")); + static TestCaseBuilder builder(String name) { + return new TestCaseBuilder() + .clientConfigure(c -> { + }) + .name(name); } - private SdkHttpFullRequest marshallRequest(CopyDbSnapshotRequest request) { - SdkHttpFullRequest.Builder marshalled = marshaller.marshall(request).toBuilder(); + private static String normalize(URI uri) { + String uriAsString = uri.toString(); + int queryStart = uriAsString.indexOf('?'); + if (queryStart == -1) { + return uriAsString; + } + String uriQueryPrefix = uriAsString.substring(0, queryStart); + String query = uri.getQuery(); + if (query == null) { + return uriAsString; + } + if (!query.isEmpty()) { + String[] queryParts = query.split("&"); + query = Arrays.stream(queryParts) + .sorted() + .collect(Collectors.joining("&")); - URI endpoint = new DefaultServiceEndpointBuilder("rds", Protocol.HTTPS.toString()) - .withRegion(DESTINATION_REGION) - .getServiceEndpoint(); - return marshalled.uri(endpoint).build(); + } + return uriQueryPrefix + "?" + query; } - private ExecutionAttributes executionAttributes() { - return new ExecutionAttributes().putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, CREDENTIALS) - .putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, DESTINATION_REGION) - .putAttribute(SdkExecutionAttribute.PROFILE_FILE_SUPPLIER, - ProfileFile::defaultProfileFile) - .putAttribute(SdkExecutionAttribute.PROFILE_NAME, "default"); + static class TestCase { + private final String name; + private final Consumer clientConfigure; + private final Consumer clientConsumer; + private final Boolean shouldContainPreSignedUrl; + private final String expectedDestinationRegion; + private final Clock signingClockOverride; + private final String expectedUri; + + TestCase(TestCaseBuilder builder) { + this.name = Validate.notNull(builder.name, "name"); + this.clientConsumer = Validate.notNull(builder.clientConsumer, "clientConsumer"); + this.clientConfigure = Validate.notNull(builder.clientConfigure, "clientConfigure"); + this.shouldContainPreSignedUrl = builder.shouldContainPreSignedUrl; + this.expectedDestinationRegion = builder.expectedDestinationRegion; + this.signingClockOverride = builder.signingClockOverride; + this.expectedUri = builder.expectedUri; + } } - private CopyDbSnapshotRequest makeTestRequest() { - return CopyDbSnapshotRequest.builder() - .sourceDBSnapshotIdentifier("arn:aws:rds:us-east-1:123456789012:snapshot:rds:test-instance-ss-2016-12-20-23-19") - .targetDBSnapshotIdentifier("test-instance-ss-copy-2") - .sourceRegion("us-east-1") - .kmsKeyId("arn:aws:kms:us-west-2:123456789012:key/11111111-2222-3333-4444-555555555555") - .build(); + static class TestCaseBuilder { + private String name; + private Consumer clientConfigure; + private Consumer clientConsumer; + private Boolean shouldContainPreSignedUrl; + private String expectedDestinationRegion; + private Clock signingClockOverride; + private String expectedUri; + + private TestCaseBuilder name(String name) { + this.name = name; + return this; + } + + private TestCaseBuilder clientConfigure(Consumer clientConfigure) { + this.clientConfigure = clientConfigure; + return this; + } + + private TestCaseBuilder clientConsumer(Consumer clientConsumer) { + this.clientConsumer = clientConsumer; + return this; + } + + private TestCaseBuilder shouldContainPreSignedUrl(Boolean value) { + this.shouldContainPreSignedUrl = value; + return this; + } + + private TestCaseBuilder expectedDestinationRegion(String value) { + this.expectedDestinationRegion = value; + return this; + } + + public TestCaseBuilder signingClockOverride(Clock signingClockOverride) { + this.signingClockOverride = signingClockOverride; + return this; + } + + public TestCaseBuilder expectedUri(String expectedUri) { + this.expectedUri = expectedUri; + return this; + } + + public TestCase build() { + return new TestCase(this); + } } - private SdkHttpRequest modifyHttpRequest(ExecutionInterceptor interceptor, - RdsRequest request, - SdkHttpFullRequest httpRequest) { - InterceptorContext context = InterceptorContext.builder().request(request).httpRequest(httpRequest).build(); - return interceptor.modifyHttpRequest(context, executionAttributes()); + static class CapturingInterceptor implements ExecutionInterceptor { + private static final RuntimeException BOOM = new RuntimeException("boom!"); + private Context.BeforeTransmission context; + private ExecutionAttributes executionAttributes; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + this.context = context; + this.executionAttributes = executionAttributes; + throw BOOM; + } + + public ExecutionAttributes executionAttributes() { + return executionAttributes; + } + + public SdkHttpRequest httpRequest() { + return context.httpRequest(); + } } }