Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import org.reactivestreams.Publisher;
import software.amazon.awssdk.annotations.SdkProtectedApi;
Expand All @@ -29,8 +28,6 @@
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.internal.interceptor.DefaultFailedExecutionContext;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.utils.Logger;
Expand Down Expand Up @@ -66,9 +63,11 @@ public InterceptorContext modifyRequest(InterceptorContext context, ExecutionAtt
InterceptorContext result = context;
for (ExecutionInterceptor interceptor : interceptors) {
SdkRequest interceptorResult = interceptor.modifyRequest(result, executionAttributes);
validateInterceptorResult(result.request(), interceptorResult, interceptor, "modifyRequest");

result = result.copy(b -> b.request(interceptorResult));
if (interceptorResult != result.request()) {
validateInterceptorResult(result.request(), interceptorResult, interceptor, "modifyRequest");
result = result.copy(b -> b.request(interceptorResult));
}
}
return result;
}
Expand All @@ -88,46 +87,20 @@ public InterceptorContext modifyHttpRequestAndHttpContent(InterceptorContext con
AsyncRequestBody asyncRequestBody = interceptor.modifyAsyncHttpContent(result, executionAttributes).orElse(null);
RequestBody requestBody = interceptor.modifyHttpContent(result, executionAttributes).orElse(null);
SdkHttpRequest interceptorResult = interceptor.modifyHttpRequest(result, executionAttributes);
validateInterceptorResult(result.httpRequest(), interceptorResult, interceptor, "modifyHttpRequest");

InterceptorContext.Builder builder = result.toBuilder();
if (asyncRequestBody != result.asyncRequestBody().orElse(null) ||
requestBody != result.requestBody().orElse(null) ||
interceptorResult != result.httpRequest()) {

applySdkHttpFullRequestHack(result, builder);

result = builder.httpRequest(interceptorResult)
.asyncRequestBody(asyncRequestBody)
.requestBody(requestBody)
.build();
validateInterceptorResult(result.httpRequest(), interceptorResult, interceptor, "modifyHttpRequest");
result = result.copy(r -> r.httpRequest(interceptorResult)
.asyncRequestBody(asyncRequestBody)
.requestBody(requestBody));
}
}
return result;
}

private void applySdkHttpFullRequestHack(InterceptorContext context, InterceptorContext.Builder builder) {
// Someone thought it would be a great idea to allow interceptors to return SdkHttpFullRequest to modify the payload
// instead of using the modifyPayload method. This is for backwards-compatibility with those interceptors.
// TODO: Update interceptors to use the proper payload-modifying method so that this code path is only used for older
// client versions. Maybe if we ever decide to break @SdkProtectedApis (if we stop using Jackson?!) we can even remove
// this hack!
SdkHttpFullRequest sdkHttpFullRequest = (SdkHttpFullRequest) context.httpRequest();

if (context.requestBody().isPresent()) {
return;
}

Optional<ContentStreamProvider> contentStreamProvider = sdkHttpFullRequest.contentStreamProvider();

if (!contentStreamProvider.isPresent()) {
return;
}

long contentLength = Long.parseLong(sdkHttpFullRequest.firstMatchingHeader("Content-Length").orElse("0"));
String contentType = sdkHttpFullRequest.firstMatchingHeader("Content-Type").orElse("");
RequestBody requestBody = RequestBody.fromContentProvider(contentStreamProvider.get(),
contentLength,
contentType);
builder.requestBody(requestBody);
}

public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
interceptors.forEach(i -> i.beforeTransmission(context, executionAttributes));
}
Expand All @@ -143,11 +116,15 @@ public InterceptorContext modifyHttpResponse(InterceptorContext context,
for (int i = interceptors.size() - 1; i >= 0; i--) {
SdkHttpResponse interceptorResult =
interceptors.get(i).modifyHttpResponse(result, executionAttributes);
validateInterceptorResult(result.httpResponse(), interceptorResult, interceptors.get(i), "modifyHttpResponse");

InputStream response = interceptors.get(i).modifyHttpResponseContent(result, executionAttributes).orElse(null);

result = result.toBuilder().httpResponse(interceptorResult).responseBody(response).build();
if (interceptorResult != result.httpResponse() || response != result.responseBody().orElse(null)) {
validateInterceptorResult(result.httpResponse(), interceptorResult, interceptors.get(i), "modifyHttpResponse");
result = result.copy(r -> r.httpResponse(interceptorResult)
.responseBody(response));
}


}

return result;
Expand All @@ -163,9 +140,9 @@ public InterceptorContext modifyAsyncHttpResponse(InterceptorContext context,
Publisher<ByteBuffer> newResponsePublisher =
interceptor.modifyAsyncHttpResponseContent(result, executionAttributes).orElse(null);

result = result.toBuilder()
.responsePublisher(newResponsePublisher)
.build();
if (newResponsePublisher != result.responsePublisher().orElse(null)) {
result = result.copy(r -> r.responsePublisher(newResponsePublisher));
}
}

return result;
Expand All @@ -183,9 +160,11 @@ public InterceptorContext modifyResponse(InterceptorContext context, ExecutionAt
InterceptorContext result = context;
for (int i = interceptors.size() - 1; i >= 0; i--) {
SdkResponse interceptorResult = interceptors.get(i).modifyResponse(result, executionAttributes);
validateInterceptorResult(result.response(), interceptorResult, interceptors.get(i), "modifyResponse");

result = result.copy(b -> b.response(interceptorResult));
if (interceptorResult != result.response()) {
validateInterceptorResult(result.response(), interceptorResult, interceptors.get(i), "modifyResponse");
result = result.copy(b -> b.response(interceptorResult));
}
}

return result;
Expand All @@ -200,8 +179,12 @@ public DefaultFailedExecutionContext modifyException(DefaultFailedExecutionConte
DefaultFailedExecutionContext result = context;
for (int i = interceptors.size() - 1; i >= 0; i--) {
Throwable interceptorResult = interceptors.get(i).modifyException(result, executionAttributes);
validateInterceptorResult(result.exception(), interceptorResult, interceptors.get(i), "modifyException");
result = result.copy(b -> b.exception(interceptorResult));

if (interceptorResult != result.exception()) {
validateInterceptorResult(result.exception(), interceptorResult,
interceptors.get(i), "modifyException");
result = result.copy(b -> b.exception(interceptorResult));
}
}

return result;
Expand Down