diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-4c86865.json b/.changes/next-release/bugfix-AWSSDKforJavav2-4c86865.json new file mode 100644 index 000000000000..9ba952df8138 --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-4c86865.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Add content-length header in Json and Xml Protocol Marshaller for String and Binary explicit Payloads." +} diff --git a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java index b81110bd1026..92f34920d2c8 100644 --- a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java +++ b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java @@ -181,12 +181,16 @@ void doMarshall(SdkPojo pojo) { Object val = field.getValueOrDefault(pojo); if (isExplicitBinaryPayload(field)) { if (val != null) { - request.contentStreamProvider(((SdkBytes) val)::asInputStream); + SdkBytes sdkBytes = (SdkBytes) val; + request.contentStreamProvider(sdkBytes::asInputStream); + updateContentLengthHeader(sdkBytes.asByteArrayUnsafe().length); } } else if (isExplicitStringPayload(field)) { if (val != null) { byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8); request.contentStreamProvider(() -> new ByteArrayInputStream(content)); + updateContentLengthHeader(content.length); + } } else if (isExplicitPayloadMember(field)) { marshallExplicitJsonPayload(field, val); @@ -196,6 +200,10 @@ void doMarshall(SdkPojo pojo) { } } + private void updateContentLengthHeader(int contentLength) { + request.putHeader(CONTENT_LENGTH, Integer.toString(contentLength)); + } + private boolean isExplicitBinaryPayload(SdkField field) { return isExplicitPayloadMember(field) && MarshallingType.SDK_BYTES.equals(field.marshallingType()); } diff --git a/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/marshall/XmlProtocolMarshaller.java b/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/marshall/XmlProtocolMarshaller.java index c8f392251d65..2f6f6bb89eb7 100644 --- a/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/marshall/XmlProtocolMarshaller.java +++ b/core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/marshall/XmlProtocolMarshaller.java @@ -91,8 +91,10 @@ void doMarshall(SdkPojo pojo) { Object val = field.getValueOrDefault(pojo); if (isBinary(field, val)) { - request.contentStreamProvider(((SdkBytes) val)::asInputStream); + SdkBytes sdkBytes = (SdkBytes) val; + request.contentStreamProvider(sdkBytes::asInputStream); setContentTypeHeaderIfNeeded("binary/octet-stream"); + request.putHeader(CONTENT_LENGTH, Integer.toString(sdkBytes.asByteArrayUnsafe().length)); } else if (isExplicitPayloadMember(field) && val instanceof String) { byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8); diff --git a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-input.json b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-input.json index 4c58440beb23..8528b322c376 100644 --- a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-input.json +++ b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-input.json @@ -442,6 +442,11 @@ }, "then": { "serializedAs": { + "headers": { + "contains": { + "content-length": "8" + } + }, "body": { "equals": "contents" } diff --git a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-input.json b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-input.json index fdf6882d7306..de141ecdbfa3 100644 --- a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-input.json +++ b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-input.json @@ -52,6 +52,11 @@ }, "then": { "serializedAs": { + "headers": { + "contains": { + "Content-length": "22" + } + }, "body": { "jsonEquals": "{\"StringMember\": \"foo\"}" } diff --git a/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/contentlength/MarshallersAddContentLengthTest.java b/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/contentlength/MarshallersAddContentLengthTest.java new file mode 100644 index 000000000000..8e8b44208ee3 --- /dev/null +++ b/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/contentlength/MarshallersAddContentLengthTest.java @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.protocol.tests.contentlength; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.crt.AwsCrtHttpClient; +import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient; +import software.amazon.awssdk.services.protocolrestjson.model.OperationWithExplicitPayloadStructureResponse; +import software.amazon.awssdk.services.protocolrestjson.model.SimpleStruct; +import software.amazon.awssdk.services.protocolrestxml.ProtocolRestXmlClient; +import software.amazon.awssdk.services.protocolrestxml.model.OperationWithExplicitPayloadStringResponse; + +@WireMockTest +public class MarshallersAddContentLengthTest { + public static final String STRING_PAYLOAD = "TEST_STRING_PAYLOAD"; + + @Test + void jsonMarshallers_AddContentLength_for_explicitBinaryPayload(WireMockRuntimeInfo wireMock) { + stubSuccessfulResponse(); + CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor(); + ProtocolRestJsonClient client = ProtocolRestJsonClient.builder() + .httpClient(AwsCrtHttpClient.builder().build()) + .overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor)) + .endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort())) + .build(); + client.operationWithExplicitPayloadBlob(p -> p.payloadMember(SdkBytes.fromString(STRING_PAYLOAD, + StandardCharsets.UTF_8))); + verify(postRequestedFor(anyUrl()).withHeader(CONTENT_LENGTH, equalTo(String.valueOf(STRING_PAYLOAD.length())))); + assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH)) + .contains(String.valueOf(STRING_PAYLOAD.length())); + } + + @Test + void jsonMarshallers_AddContentLength_for_explicitStringPayload(WireMockRuntimeInfo wireMock) { + stubSuccessfulResponse(); + String expectedPayload = String.format("{\"StringMember\":\"%s\"}", STRING_PAYLOAD); + CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor(); + ProtocolRestJsonClient client = ProtocolRestJsonClient.builder() + .httpClient(AwsCrtHttpClient.builder().build()) + .overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor)) + .endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort())) + .build(); + OperationWithExplicitPayloadStructureResponse response = + client.operationWithExplicitPayloadStructure(p -> p.payloadMember(SimpleStruct.builder().stringMember(STRING_PAYLOAD).build())); + verify(postRequestedFor(anyUrl()) + .withRequestBody(equalTo(expectedPayload)) + .withHeader(CONTENT_LENGTH, equalTo(String.valueOf(expectedPayload.length())))); + assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH)) + .contains(String.valueOf(expectedPayload.length())); + } + + @Test + void xmlMarshallers_AddContentLength_for_explicitBinaryPayload(WireMockRuntimeInfo wireMock) { + stubSuccessfulResponse(); + CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor(); + ProtocolRestXmlClient client = ProtocolRestXmlClient.builder() + .httpClient(AwsCrtHttpClient.builder().build()) + .overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor)) + .endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort())) + .build(); + client.operationWithExplicitPayloadBlob(r -> r.payloadMember(SdkBytes.fromString(STRING_PAYLOAD, + StandardCharsets.UTF_8))); + verify(postRequestedFor(anyUrl()).withRequestBody(equalTo(STRING_PAYLOAD)) + .withHeader(CONTENT_LENGTH, equalTo(String.valueOf(STRING_PAYLOAD.length())))); + assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH)) + .contains(String.valueOf(STRING_PAYLOAD.length())); + } + + @Test + void xmlMarshallers_AddContentLength_for_explicitStringPayload(WireMockRuntimeInfo wireMock) { + stubSuccessfulResponse(); + String expectedPayload = STRING_PAYLOAD; + CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor(); + ProtocolRestXmlClient client = ProtocolRestXmlClient.builder() + .httpClient(AwsCrtHttpClient.builder().build()) + .overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor)) + .endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort())) + .build(); + OperationWithExplicitPayloadStringResponse stringResponse = + client.operationWithExplicitPayloadString(p -> p.payloadMember(STRING_PAYLOAD)); + verify(postRequestedFor(anyUrl()) + .withRequestBody(equalTo(expectedPayload)) + .withHeader(CONTENT_LENGTH, equalTo(String.valueOf(expectedPayload.length())))); + assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH)) + .contains(String.valueOf(expectedPayload.length())); + } + + private void stubSuccessfulResponse() { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200))); + } + + private static class CaptureRequestInterceptor implements ExecutionInterceptor { + private SdkHttpRequest requestAfterMarshilling; + + public SdkHttpRequest requestAfterMarshalling() { + return requestAfterMarshilling; + } + + @Override + public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) { + this.requestAfterMarshilling = context.httpRequest(); + } + } +}