From 5200407c18a1b143a8cece2b7d74c77bea18fa67 Mon Sep 17 00:00:00 2001 From: Matthew Miller Date: Wed, 9 Feb 2022 13:18:55 -0800 Subject: [PATCH] Standardize the way SdkBytes are unmarshalled when they're part of a payload and the service returns no content. After this change, when SdkBytes are modeled as a payload, the SDK will always return a non-null empty SdkBytes. When SdkBytes are modeled as a field, they will be null if the service did not specify them. --- .../bugfix-AWSSDKforJavav2-98d2e80.json | 6 + bom-internal/pom.xml | 9 +- core/auth-crt/pom.xml | 2 +- core/auth/pom.xml | 2 +- core/aws-core/pom.xml | 2 +- core/metrics-spi/pom.xml | 2 +- .../unmarshall/JsonProtocolUnmarshaller.java | 12 +- .../unmarshall/QueryProtocolUnmarshaller.java | 24 ++ .../query/unmarshall/XmlDomParser.java | 11 +- .../unmarshall/XmlProtocolUnmarshaller.java | 13 +- .../unmarshall/XmlResponseParserUtils.java | 17 +- core/regions/pom.xml | 2 +- core/sdk-core/pom.xml | 2 +- http-clients/apache-client/pom.xml | 1 + http-clients/aws-crt-client/pom.xml | 1 + http-clients/netty-nio-client/pom.xml | 1 + metric-publishers/pom.xml | 2 +- pom.xml | 2 +- services-custom/dynamodb-enhanced/pom.xml | 2 +- services/pom.xml | 2 +- services/s3/pom.xml | 2 +- test/auth-sts-testing/pom.xml | 7 +- test/codegen-generated-classes-test/pom.xml | 2 +- .../BlobUnmarshallingTest.java | 249 ++++++++++++++++++ test/http-client-tests/pom.xml | 1 + test/protocol-tests-core/pom.xml | 2 +- .../suites/cases/rest-core-output.json | 1 + test/protocol-tests/pom.xml | 2 +- test/stability-tests/pom.xml | 2 +- .../awssdk/utils/LookaheadInputStream.java | 127 +++++++++ 30 files changed, 470 insertions(+), 40 deletions(-) create mode 100644 .changes/next-release/bugfix-AWSSDKforJavav2-98d2e80.json create mode 100644 test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/BlobUnmarshallingTest.java create mode 100644 utils/src/main/java/software/amazon/awssdk/utils/LookaheadInputStream.java diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-98d2e80.json b/.changes/next-release/bugfix-AWSSDKforJavav2-98d2e80.json new file mode 100644 index 000000000000..26f2341ebcf4 --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-98d2e80.json @@ -0,0 +1,6 @@ +{ + "category": "AWS SDK for Java v2", + "contributor": "", + "type": "bugfix", + "description": "Always return an empty SDK bytes when services model their response payload as a blob. Previously, it would either return null, empty bytes or throw an exception depending on the protocol, HTTP client and whether the service was using chunked encoding for their responses." +} diff --git a/bom-internal/pom.xml b/bom-internal/pom.xml index 02aad44a54c7..978700e28ac8 100644 --- a/bom-internal/pom.xml +++ b/bom-internal/pom.xml @@ -292,10 +292,17 @@ com.github.tomakehurst - wiremock + wiremock-jre8 ${wiremock.version} test + + + com.github.tomakehurst + wiremock + 2.18.0 + compile + com.google.guava guava diff --git a/core/auth-crt/pom.xml b/core/auth-crt/pom.xml index 75dee13b9a40..fbcb6c4a9247 100644 --- a/core/auth-crt/pom.xml +++ b/core/auth-crt/pom.xml @@ -80,7 +80,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/core/auth/pom.xml b/core/auth/pom.xml index c25dc01020da..bc2b214f4734 100644 --- a/core/auth/pom.xml +++ b/core/auth/pom.xml @@ -84,7 +84,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/core/aws-core/pom.xml b/core/aws-core/pom.xml index 8b0242906fda..36d1606f9f94 100644 --- a/core/aws-core/pom.xml +++ b/core/aws-core/pom.xml @@ -111,7 +111,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/core/metrics-spi/pom.xml b/core/metrics-spi/pom.xml index 2a02330bd749..94c2919e8dd9 100644 --- a/core/metrics-spi/pom.xml +++ b/core/metrics-spi/pom.xml @@ -44,7 +44,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java index 274a9d3d247c..3ed0c9fb8cae 100644 --- a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java +++ b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.annotations.ThreadSafe; @@ -36,6 +37,7 @@ import software.amazon.awssdk.core.traits.MapTrait; import software.amazon.awssdk.core.traits.PayloadTrait; import software.amazon.awssdk.core.traits.TimestampFormatTrait; +import software.amazon.awssdk.http.AbortableInputStream; import software.amazon.awssdk.http.SdkHttpFullResponse; import software.amazon.awssdk.protocols.core.StringToInstant; import software.amazon.awssdk.protocols.core.StringToValueConverter; @@ -225,9 +227,13 @@ private static TypeT unmarshallStructured(SdkPojo sdkPoj JsonNode jsonContent, JsonUnmarshallerContext context) { for (SdkField field : sdkPojo.sdkFields()) { - if (isExplicitPayloadMember(field) && field.marshallingType() == MarshallingType.SDK_BYTES && - context.response().content().isPresent()) { - field.set(sdkPojo, SdkBytes.fromInputStream(context.response().content().get())); + if (isExplicitPayloadMember(field) && field.marshallingType() == MarshallingType.SDK_BYTES) { + Optional responseContent = context.response().content(); + if (responseContent.isPresent()) { + field.set(sdkPojo, SdkBytes.fromInputStream(responseContent.get())); + } else { + field.set(sdkPojo, SdkBytes.fromByteArrayUnsafe(new byte[0])); + } } else { JsonNode jsonFieldContent = getJsonNode(jsonContent, field); JsonUnmarshaller unmarshaller = context.getUnmarshaller(field.location(), field.marshallingType()); diff --git a/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/internal/unmarshall/QueryProtocolUnmarshaller.java b/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/internal/unmarshall/QueryProtocolUnmarshaller.java index 9666a50d88cd..b781ff9ac7c5 100644 --- a/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/internal/unmarshall/QueryProtocolUnmarshaller.java +++ b/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/internal/unmarshall/QueryProtocolUnmarshaller.java @@ -17,14 +17,17 @@ import static software.amazon.awssdk.awscore.util.AwsHeader.AWS_REQUEST_ID; import static software.amazon.awssdk.protocols.query.internal.marshall.SimpleTypeQueryMarshaller.defaultTimestampFormats; +import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; import java.util.HashMap; import java.util.List; import java.util.Map; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.PayloadTrait; import software.amazon.awssdk.http.SdkHttpFullResponse; import software.amazon.awssdk.protocols.core.StringToInstant; import software.amazon.awssdk.protocols.core.StringToValueConverter; @@ -32,6 +35,7 @@ import software.amazon.awssdk.protocols.query.unmarshall.XmlElement; import software.amazon.awssdk.protocols.query.unmarshall.XmlErrorUnmarshaller; import software.amazon.awssdk.utils.CollectionUtils; +import software.amazon.awssdk.utils.IoUtils; import software.amazon.awssdk.utils.Pair; import software.amazon.awssdk.utils.builder.Buildable; @@ -69,11 +73,26 @@ private QueryProtocolUnmarshaller(Builder builder) { public Pair> unmarshall(SdkPojo sdkPojo, SdkHttpFullResponse response) { + if (responsePayloadIsBlob(sdkPojo)) { + XmlElement document = XmlElement.builder() + .textContent(response.content() + .map(s -> invokeSafely(() -> IoUtils.toUtf8String(s))) + .orElse("")) + .build(); + return Pair.of(unmarshall(sdkPojo, document, response), new HashMap<>()); + } + XmlElement document = response.content().map(XmlDomParser::parse).orElseGet(XmlElement::empty); XmlElement resultRoot = hasResultWrapper ? document.getFirstChild() : document; return Pair.of(unmarshall(sdkPojo, resultRoot, response), parseMetadata(document)); } + private boolean responsePayloadIsBlob(SdkPojo sdkPojo) { + return sdkPojo.sdkFields().stream() + .anyMatch(field -> field.marshallingType() == MarshallingType.SDK_BYTES && + field.containsTrait(PayloadTrait.class)); + } + /** * This method is also used to unmarshall exceptions. We use this since we've already parsed the XML * and the result root is in a different location depending on the protocol/service. @@ -109,6 +128,10 @@ private String metadataKeyName(XmlElement c) { private SdkPojo unmarshall(QueryUnmarshallerContext context, SdkPojo sdkPojo, XmlElement root) { if (root != null) { for (SdkField field : sdkPojo.sdkFields()) { + if (field.containsTrait(PayloadTrait.class) && field.marshallingType() == MarshallingType.SDK_BYTES) { + field.set(sdkPojo, SdkBytes.fromUtf8String(root.textContent())); + } + List element = root.getElementsByName(field.unmarshallLocationName()); if (!CollectionUtils.isNullOrEmpty(element)) { QueryUnmarshaller unmarshaller = @@ -118,6 +141,7 @@ private SdkPojo unmarshall(QueryUnmarshallerContext context, SdkPojo sdkPojo, Xm } } } + return (SdkPojo) ((Buildable) sdkPojo).build(); } diff --git a/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/unmarshall/XmlDomParser.java b/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/unmarshall/XmlDomParser.java index 59be4ec94c33..c4e458bf9e93 100644 --- a/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/unmarshall/XmlDomParser.java +++ b/core/protocols/aws-query-protocol/src/main/java/software/amazon/awssdk/protocols/query/unmarshall/XmlDomParser.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.protocols.query.unmarshall; +import java.io.IOException; import java.io.InputStream; import java.util.HashMap; import java.util.Iterator; @@ -27,6 +28,7 @@ import javax.xml.stream.events.XMLEvent; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.utils.LookaheadInputStream; /** * Parses an XML document into a simple DOM like structure, {@link XmlElement}. @@ -40,15 +42,20 @@ private XmlDomParser() { } public static XmlElement parse(InputStream inputStream) { + LookaheadInputStream stream = new LookaheadInputStream(inputStream); try { - XMLEventReader reader = FACTORY.get().createXMLEventReader(inputStream); + if (stream.peek() == -1) { + return XmlElement.empty(); + } + + XMLEventReader reader = FACTORY.get().createXMLEventReader(stream); XMLEvent nextEvent; // Skip ahead to the first start element do { nextEvent = reader.nextEvent(); } while (reader.hasNext() && !nextEvent.isStartElement()); return parseElement(nextEvent.asStartElement(), reader); - } catch (XMLStreamException e) { + } catch (IOException | XMLStreamException e) { throw SdkClientException.create("Could not parse XML response.", e); } } diff --git a/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlProtocolUnmarshaller.java b/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlProtocolUnmarshaller.java index b29521b58853..a975672587cf 100644 --- a/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlProtocolUnmarshaller.java +++ b/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlProtocolUnmarshaller.java @@ -85,7 +85,13 @@ SdkPojo unmarshall(XmlUnmarshallerContext context, SdkPojo sdkPojo, XmlElement r if (root != null && field.location() == MarshallLocation.PAYLOAD) { if (!context.response().content().isPresent()) { - // This is a payload field, but the service sent no content. Do not populate this field (leave it null). + // This is a payload field, but the service sent no content. Do not populate this field (leave it null or + // empty). + if (field.marshallingType() == MarshallingType.SDK_BYTES && field.containsTrait(PayloadTrait.class)) { + // SDK bytes bound directly to the payload field should never be left empty + field.set(sdkPojo, SdkBytes.fromByteArrayUnsafe(new byte[0])); + } + continue; } @@ -98,9 +104,8 @@ SdkPojo unmarshall(XmlUnmarshallerContext context, SdkPojo sdkPojo, XmlElement r root.getElementsByName(field.unmarshallLocationName()); if (!CollectionUtils.isNullOrEmpty(element)) { - boolean isFieldBlobTypePayload = payloadMemberAsBlobType.isPresent() - && payloadMemberAsBlobType.get().equals(field); - + boolean isFieldBlobTypePayload = payloadMemberAsBlobType.isPresent() && + payloadMemberAsBlobType.get().equals(field); if (isFieldBlobTypePayload) { field.set(sdkPojo, SdkBytes.fromInputStream(context.response().content().get())); } else { diff --git a/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlResponseParserUtils.java b/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlResponseParserUtils.java index 20eacb700656..c5d8088ec8cb 100644 --- a/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlResponseParserUtils.java +++ b/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/unmarshall/XmlResponseParserUtils.java @@ -15,9 +15,7 @@ package software.amazon.awssdk.protocols.xml.internal.unmarshall; -import java.io.BufferedInputStream; import java.io.IOException; -import java.io.InputStream; import java.io.UncheckedIOException; import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; @@ -30,6 +28,7 @@ import software.amazon.awssdk.http.SdkHttpFullResponse; import software.amazon.awssdk.protocols.query.unmarshall.XmlDomParser; import software.amazon.awssdk.protocols.query.unmarshall.XmlElement; +import software.amazon.awssdk.utils.LookaheadInputStream; /** * Static methods to assist with parsing the response of AWS XML requests. @@ -57,12 +56,10 @@ public static XmlElement parse(SdkPojo sdkPojo, SdkHttpFullResponse response) { } // Make sure there is content in the stream before passing it to the parser. - InputStream content = ensureMarkSupported(responseContent.get()); - content.mark(2); - if (content.read() == -1) { + LookaheadInputStream content = new LookaheadInputStream(responseContent.get()); + if (content.peek() == -1) { return XmlElement.empty(); } - content.reset(); return XmlDomParser.parse(content); } catch (IOException e) { @@ -75,14 +72,6 @@ public static XmlElement parse(SdkPojo sdkPojo, SdkHttpFullResponse response) { } } - private static InputStream ensureMarkSupported(AbortableInputStream content) { - if (content.markSupported()) { - return content; - } - - return new BufferedInputStream(content); - } - /** * Gets the Member which is a Payload and which is of Blob Type. * @param sdkPojo diff --git a/core/regions/pom.xml b/core/regions/pom.xml index a9584c9f035e..e96087ffb342 100644 --- a/core/regions/pom.xml +++ b/core/regions/pom.xml @@ -75,7 +75,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/core/sdk-core/pom.xml b/core/sdk-core/pom.xml index abb32a66cbe4..abbdbc7a8814 100644 --- a/core/sdk-core/pom.xml +++ b/core/sdk-core/pom.xml @@ -83,7 +83,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/http-clients/apache-client/pom.xml b/http-clients/apache-client/pom.xml index 1707ae0c82eb..602e69159b6b 100644 --- a/http-clients/apache-client/pom.xml +++ b/http-clients/apache-client/pom.xml @@ -83,6 +83,7 @@ hamcrest-all test + com.github.tomakehurst wiremock diff --git a/http-clients/aws-crt-client/pom.xml b/http-clients/aws-crt-client/pom.xml index cfefae45b578..b308fb0456d7 100644 --- a/http-clients/aws-crt-client/pom.xml +++ b/http-clients/aws-crt-client/pom.xml @@ -73,6 +73,7 @@ + com.github.tomakehurst wiremock diff --git a/http-clients/netty-nio-client/pom.xml b/http-clients/netty-nio-client/pom.xml index 865b16658c8d..89874f2589f8 100644 --- a/http-clients/netty-nio-client/pom.xml +++ b/http-clients/netty-nio-client/pom.xml @@ -107,6 +107,7 @@ ${awsjavasdk.version} test + com.github.tomakehurst wiremock diff --git a/metric-publishers/pom.xml b/metric-publishers/pom.xml index e06aa1d6d833..93d08967f599 100644 --- a/metric-publishers/pom.xml +++ b/metric-publishers/pom.xml @@ -98,7 +98,7 @@ test - wiremock + wiremock-jre8 com.github.tomakehurst test diff --git a/pom.xml b/pom.xml index 77131693757b..fa94547cef0b 100644 --- a/pom.xml +++ b/pom.xml @@ -95,7 +95,7 @@ 2.13.1 1.0.1 3.12.0 - 2.18.0 + 2.32.0 1.7.30 2.17.1 2.5 diff --git a/services-custom/dynamodb-enhanced/pom.xml b/services-custom/dynamodb-enhanced/pom.xml index 54d3257c084a..be2cee661fee 100644 --- a/services-custom/dynamodb-enhanced/pom.xml +++ b/services-custom/dynamodb-enhanced/pom.xml @@ -175,7 +175,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/services/pom.xml b/services/pom.xml index ab9450d6d04a..0a96572a7ce5 100644 --- a/services/pom.xml +++ b/services/pom.xml @@ -447,7 +447,7 @@ test - wiremock + wiremock-jre8 com.github.tomakehurst test diff --git a/services/s3/pom.xml b/services/s3/pom.xml index 903f6f536626..bc73fe9c08ba 100644 --- a/services/s3/pom.xml +++ b/services/s3/pom.xml @@ -112,7 +112,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/test/auth-sts-testing/pom.xml b/test/auth-sts-testing/pom.xml index 90676a9a8817..85fc98c3fae6 100644 --- a/test/auth-sts-testing/pom.xml +++ b/test/auth-sts-testing/pom.xml @@ -83,7 +83,12 @@ com.github.tomakehurst - wiremock + wiremock-jre8 + test + + + junit + junit test diff --git a/test/codegen-generated-classes-test/pom.xml b/test/codegen-generated-classes-test/pom.xml index 423609388c5c..11c61908dfee 100644 --- a/test/codegen-generated-classes-test/pom.xml +++ b/test/codegen-generated-classes-test/pom.xml @@ -129,7 +129,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/BlobUnmarshallingTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/BlobUnmarshallingTest.java new file mode 100644 index 000000000000..55c2e6a53644 --- /dev/null +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/BlobUnmarshallingTest.java @@ -0,0 +1,249 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.protocolrestjson; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static org.assertj.core.api.Assertions.assertThat; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.protocolquery.ProtocolQueryAsyncClient; +import software.amazon.awssdk.services.protocolquery.ProtocolQueryClient; +import software.amazon.awssdk.services.protocolrestxml.ProtocolRestXmlAsyncClient; +import software.amazon.awssdk.services.protocolrestxml.ProtocolRestXmlClient; + +/** + * This verifies that blob types are unmarshalled correctly depending on where they exist. Specifically, we currently unmarshall + * SdkBytes fields bound to the payload as empty if the service responds with no data and SdkBytes fields bound to a field as + * null if the service does not specify that field. + */ +@WireMockTest +public class BlobUnmarshallingTest { + private static List testParameters() { + List testCases = new ArrayList<>(); + for (ClientType clientType : ClientType.values()) { + for (Protocol protocol : Protocol.values()) { + for (SdkBytesLocation value : SdkBytesLocation.values()) { + for (ContentLength contentLength : ContentLength.values()) { + testCases.add(Arguments.arguments(clientType, protocol, value, contentLength)); + } + } + } + } + return testCases; + } + + private enum ClientType { + SYNC, + ASYNC + } + + private enum Protocol { + JSON, + XML, + QUERY + } + + private enum SdkBytesLocation { + PAYLOAD, + FIELD + } + + private enum ContentLength { + ZERO, + CHUNKED_ZERO + } + + @ParameterizedTest + @MethodSource("testParameters") + public void missingSdkBytes_unmarshalledCorrectly(ClientType clientType, + Protocol protocol, + SdkBytesLocation bytesLoc, + ContentLength contentLength, + WireMockRuntimeInfo wm) { + if (contentLength == ContentLength.ZERO) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("Content-Length", "0").withBody(""))); + } else if (contentLength == ContentLength.CHUNKED_ZERO) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(""))); + } + + SdkBytes serviceResult = callService(wm, clientType, protocol, bytesLoc); + + if (bytesLoc == SdkBytesLocation.PAYLOAD) { + assertThat(serviceResult).isNotNull().isEqualTo(SdkBytes.fromUtf8String("")); + } else if (bytesLoc == SdkBytesLocation.FIELD) { + assertThat(serviceResult).isNull(); + } + } + + @ParameterizedTest + @MethodSource("testParameters") + public void presentSdkBytes_unmarshalledCorrectly(ClientType clientType, + Protocol protocol, + SdkBytesLocation bytesLoc, + ContentLength contentLength, + WireMockRuntimeInfo wm) { + String responsePayload = presentSdkBytesResponse(protocol, bytesLoc); + + if (contentLength == ContentLength.ZERO) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200) + .withHeader("Content-Length", Integer.toString(responsePayload.length())) + .withBody(responsePayload))); + } else if (contentLength == ContentLength.CHUNKED_ZERO) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(responsePayload))); + } + + assertThat(callService(wm, clientType, protocol, bytesLoc)).isEqualTo(SdkBytes.fromUtf8String("X")); + } + + private String presentSdkBytesResponse(Protocol protocol, SdkBytesLocation bytesLoc) { + switch (bytesLoc) { + case PAYLOAD: return "X"; + case FIELD: + switch (protocol) { + case JSON: return "{\"BlobArg\": \"WA==\"}"; + case XML: return "WA=="; + case QUERY: return "WA=="; + default: throw new UnsupportedOperationException(); + } + default: throw new UnsupportedOperationException(); + } + + } + + private SdkBytes callService(WireMockRuntimeInfo wm, ClientType clientType, Protocol protocol, SdkBytesLocation bytesLoc) { + switch (clientType) { + case SYNC: return syncCallService(wm, protocol, bytesLoc); + case ASYNC: return asyncCallService(wm, protocol, bytesLoc); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes syncCallService(WireMockRuntimeInfo wm, Protocol protocol, SdkBytesLocation bytesLoc) { + switch (protocol) { + case JSON: return syncJsonCallService(wm, bytesLoc); + case XML: return syncXmlCallService(wm, bytesLoc); + case QUERY: return syncQueryCallService(wm, bytesLoc); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes asyncCallService(WireMockRuntimeInfo wm, Protocol protocol, SdkBytesLocation bytesLoc) { + switch (protocol) { + case JSON: return asyncJsonCallService(wm, bytesLoc); + case XML: return asyncXmlCallService(wm, bytesLoc); + case QUERY: return asyncQueryCallService(wm, bytesLoc); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes asyncQueryCallService(WireMockRuntimeInfo wm, SdkBytesLocation bytesLoc) { + ProtocolQueryAsyncClient client = + ProtocolQueryAsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (bytesLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadBlob(r -> {}).join().payloadMember(); + case FIELD: return client.allTypes(r -> {}).join().blobArg(); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes asyncXmlCallService(WireMockRuntimeInfo wm, SdkBytesLocation bytesLoc) { + ProtocolRestXmlAsyncClient client = + ProtocolRestXmlAsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (bytesLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadBlob(r -> {}).join().payloadMember(); + case FIELD: return client.allTypes(r -> {}).join().blobArg(); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes asyncJsonCallService(WireMockRuntimeInfo wm, SdkBytesLocation bytesLoc) { + ProtocolRestJsonAsyncClient client = + ProtocolRestJsonAsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (bytesLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadBlob(r -> {}).join().payloadMember(); + case FIELD: return client.allTypes(r -> {}).join().blobArg(); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes syncQueryCallService(WireMockRuntimeInfo wm, SdkBytesLocation bytesLoc) { + ProtocolQueryClient client = + ProtocolQueryClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (bytesLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadBlob(r -> {}).payloadMember(); + case FIELD: return client.allTypes(r -> {}).blobArg(); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes syncXmlCallService(WireMockRuntimeInfo wm, SdkBytesLocation bytesLoc) { + ProtocolRestXmlClient client = + ProtocolRestXmlClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (bytesLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadBlob(r -> {}).payloadMember(); + case FIELD: return client.allTypes(r -> {}).blobArg(); + default: throw new UnsupportedOperationException(); + } + } + + private SdkBytes syncJsonCallService(WireMockRuntimeInfo wm, SdkBytesLocation bytesLoc) { + ProtocolRestJsonClient client = + ProtocolRestJsonClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (bytesLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadBlob(r -> {}).payloadMember(); + case FIELD: return client.allTypes(r -> {}).blobArg(); + default: throw new UnsupportedOperationException(); + } + } +} diff --git a/test/http-client-tests/pom.xml b/test/http-client-tests/pom.xml index cfdbbf34c47f..178c22373227 100644 --- a/test/http-client-tests/pom.xml +++ b/test/http-client-tests/pom.xml @@ -83,6 +83,7 @@ mockito-core compile + com.github.tomakehurst wiremock diff --git a/test/protocol-tests-core/pom.xml b/test/protocol-tests-core/pom.xml index 50ed8462d8ea..f153b6213d7c 100644 --- a/test/protocol-tests-core/pom.xml +++ b/test/protocol-tests-core/pom.xml @@ -97,7 +97,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 compile diff --git a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-output.json b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-output.json index 4b6cad3e326f..7dedb7a409f2 100644 --- a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-output.json +++ b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-output.json @@ -117,6 +117,7 @@ }, "then": { "deserializedAs": { + "PayloadMember": "" } } }, diff --git a/test/protocol-tests/pom.xml b/test/protocol-tests/pom.xml index 58fa8fa47f02..6bb1843ae4e7 100644 --- a/test/protocol-tests/pom.xml +++ b/test/protocol-tests/pom.xml @@ -155,7 +155,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/test/stability-tests/pom.xml b/test/stability-tests/pom.xml index 191dbb5432f9..d1ebf7978d1a 100644 --- a/test/stability-tests/pom.xml +++ b/test/stability-tests/pom.xml @@ -182,7 +182,7 @@ com.github.tomakehurst - wiremock + wiremock-jre8 test diff --git a/utils/src/main/java/software/amazon/awssdk/utils/LookaheadInputStream.java b/utils/src/main/java/software/amazon/awssdk/utils/LookaheadInputStream.java new file mode 100644 index 000000000000..da9dd893568d --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/LookaheadInputStream.java @@ -0,0 +1,127 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import software.amazon.awssdk.annotations.SdkProtectedApi; + +/** + * A wrapper for an {@link InputStream} that allows {@link #peek()}ing one byte ahead in the stream. This is useful for + * detecting the end of a stream without actually consuming any data in the process (e.g. so the stream can be passed to + * another library that doesn't handle end-of-stream as the first byte well). + */ +@SdkProtectedApi +public class LookaheadInputStream extends FilterInputStream { + private Integer next; + private Integer nextAtMark; + + public LookaheadInputStream(InputStream in) { + super(in); + } + + public int peek() throws IOException { + if (next == null) { + next = read(); + } + + return next; + } + + @Override + public int read() throws IOException { + if (next == null) { + return super.read(); + } + + Integer next = this.next; + this.next = null; + return next; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (next == null) { + return super.read(b, off, len); + } + + if (len == 0) { + return 0; + } + + if (next == -1) { + return -1; + } + + b[off] = (byte) (int) next; + next = null; + + if (len == 1) { + return 1; + } + + return super.read(b, off + 1, b.length - 1) + 1; + } + + @Override + public long skip(long n) throws IOException { + if (next == null) { + return super.skip(n); + } + + if (n == 0) { + return 0; + } + + if (next == -1) { + return 0; + } + + next = null; + + if (n == 1) { + return 1; + } + + return super.skip(n - 1); + } + + @Override + public int available() throws IOException { + if (next == null) { + return super.available(); + } + + return super.available() + 1; + } + + @Override + public synchronized void mark(int readlimit) { + if (next == null) { + super.mark(readlimit); + } else { + nextAtMark = next; + super.mark(readlimit - 1); + } + } + + @Override + public synchronized void reset() throws IOException { + next = nextAtMark; + super.reset(); + } +}