Skip to content

Commit a7aba17

Browse files
committed
Content length updation done in the Marshallers after internal comment
1 parent 1f4a2da commit a7aba17

File tree

4 files changed

+152
-122
lines changed

4 files changed

+152
-122
lines changed

core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ void doMarshall(SdkPojo pojo) {
201201
}
202202

203203
private void updateContentLengthHeader(int contentLength) {
204-
if (!request.headers().containsKey(CONTENT_LENGTH)) {
204+
if (!request.firstMatchingHeader(CONTENT_LENGTH).isPresent()) {
205205
request.putHeader(CONTENT_LENGTH, Integer.toString(contentLength));
206206
}
207207
}

core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/marshall/XmlProtocolMarshaller.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,11 @@ void doMarshall(SdkPojo pojo) {
9191
Object val = field.getValueOrDefault(pojo);
9292

9393
if (isBinary(field, val)) {
94-
request.contentStreamProvider(((SdkBytes) val)::asInputStream);
94+
SdkBytes sdkBytes = (SdkBytes) val;
95+
request.contentStreamProvider(sdkBytes::asInputStream);
9596
setContentTypeHeaderIfNeeded("binary/octet-stream");
97+
request.putHeader(CONTENT_LENGTH, Integer.toString(sdkBytes.asByteArray().length));
98+
9699

97100
} else if (isExplicitPayloadMember(field) && val instanceof String) {
98101
byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8);

http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/AwsCrtSyncWireMockTest.java

Lines changed: 0 additions & 120 deletions
This file was deleted.
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.protocol.tests.contentlength;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
20+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
21+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
22+
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
23+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
24+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
25+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
26+
import static software.amazon.awssdk.http.Header.CONTENT_LENGTH;
27+
28+
import com.github.tomakehurst.wiremock.junit.WireMockRule;
29+
import java.net.URI;
30+
import java.nio.charset.StandardCharsets;
31+
import org.junit.Rule;
32+
import org.junit.Test;
33+
import software.amazon.awssdk.core.SdkBytes;
34+
import software.amazon.awssdk.core.interceptor.Context;
35+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
36+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
37+
import software.amazon.awssdk.http.SdkHttpRequest;
38+
import software.amazon.awssdk.http.crt.AwsCrtHttpClient;
39+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
40+
import software.amazon.awssdk.services.protocolrestjson.model.OperationWithExplicitPayloadStructureResponse;
41+
import software.amazon.awssdk.services.protocolrestjson.model.SimpleStruct;
42+
import software.amazon.awssdk.services.protocolrestxml.ProtocolRestXmlClient;
43+
import software.amazon.awssdk.services.protocolrestxml.model.OperationWithExplicitPayloadStringResponse;
44+
45+
public class MarshallersAddContentLengthTest {
46+
public static final String STRING_PAYLOAD = "TEST_STRING_PAYLOAD";
47+
@Rule
48+
public WireMockRule wireMock = new WireMockRule(0);
49+
50+
@Test
51+
public void jsonMarshallers_AddContentLength_for_explicitBinaryPayload() {
52+
stubSuccessfulResponse();
53+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
54+
ProtocolRestJsonClient client = ProtocolRestJsonClient.builder()
55+
.httpClient(AwsCrtHttpClient.builder().build())
56+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
57+
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
58+
.build();
59+
60+
client.operationWithExplicitPayloadBlob(p -> p.payloadMember(SdkBytes.fromString(STRING_PAYLOAD,
61+
StandardCharsets.UTF_8)));
62+
verify(postRequestedFor(anyUrl()).withHeader(CONTENT_LENGTH, equalTo(String.valueOf(STRING_PAYLOAD.length()))));
63+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH).get())
64+
.isEqualTo(String.valueOf(STRING_PAYLOAD.length()));
65+
66+
}
67+
68+
@Test
69+
public void jsonMarshallers_AddContentLength_for_explicitStringPayload() {
70+
stubSuccessfulResponse();
71+
String expectedPayload = String.format("{\"StringMember\":\"%s\"}", STRING_PAYLOAD);
72+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
73+
ProtocolRestJsonClient client = ProtocolRestJsonClient.builder()
74+
.httpClient(AwsCrtHttpClient.builder().build())
75+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
76+
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
77+
.build();
78+
79+
OperationWithExplicitPayloadStructureResponse response =
80+
client.operationWithExplicitPayloadStructure(p -> p.payloadMember(SimpleStruct.builder().stringMember(STRING_PAYLOAD).build()));
81+
verify(postRequestedFor(anyUrl())
82+
.withRequestBody(equalTo(expectedPayload))
83+
.withHeader(CONTENT_LENGTH, equalTo(String.valueOf(expectedPayload.length()))));
84+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH).get())
85+
.isEqualTo(String.valueOf(expectedPayload.length()));
86+
87+
}
88+
89+
@Test
90+
public void xmlMarshallers_AddContentLength_for_explicitBinaryPayload() {
91+
stubSuccessfulResponse();
92+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
93+
ProtocolRestXmlClient client = ProtocolRestXmlClient.builder()
94+
.httpClient(AwsCrtHttpClient.builder().build())
95+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
96+
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
97+
.build();
98+
client.operationWithExplicitPayloadBlob(r -> r.payloadMember(SdkBytes.fromString(STRING_PAYLOAD,
99+
StandardCharsets.UTF_8)));
100+
101+
verify(postRequestedFor(anyUrl())
102+
.withRequestBody(equalTo(STRING_PAYLOAD))
103+
104+
.withHeader(CONTENT_LENGTH, equalTo(String.valueOf(STRING_PAYLOAD.length()))));
105+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH).get())
106+
.isEqualTo(String.valueOf(
107+
STRING_PAYLOAD.length()));
108+
}
109+
110+
@Test
111+
public void xmlMarshallers_AddContentLength_for_explicitStringPayload() {
112+
stubSuccessfulResponse();
113+
String expectedPayload = STRING_PAYLOAD;
114+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
115+
ProtocolRestXmlClient client = ProtocolRestXmlClient.builder()
116+
.httpClient(AwsCrtHttpClient.builder().build())
117+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
118+
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
119+
.build();
120+
OperationWithExplicitPayloadStringResponse stringResponse =
121+
client.operationWithExplicitPayloadString(p -> p.payloadMember(STRING_PAYLOAD));
122+
verify(postRequestedFor(anyUrl())
123+
.withRequestBody(equalTo(expectedPayload))
124+
.withHeader(CONTENT_LENGTH, equalTo(String.valueOf(expectedPayload.length()))));
125+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH).get())
126+
.isEqualTo(String.valueOf(expectedPayload.length()));
127+
128+
}
129+
130+
private void stubSuccessfulResponse() {
131+
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200)));
132+
}
133+
134+
private static class CaptureRequestInterceptor implements ExecutionInterceptor {
135+
private SdkHttpRequest requestAfterMarshilling;
136+
137+
public SdkHttpRequest requestAfterMarshalling() {
138+
return requestAfterMarshilling;
139+
}
140+
141+
@Override
142+
public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) {
143+
this.requestAfterMarshilling = context.httpRequest();
144+
}
145+
146+
}
147+
}

0 commit comments

Comments
 (0)